Compare commits
No commits in common. "db465f1bf05c747e065ad84ee7daa37547af442c" and "aeeb7bb7f79ccc5a1380b47a54e4b33559381101" have entirely different histories.
db465f1bf0 ... aeeb7bb7f7
@@ -3,7 +3,6 @@ package cmd
 import (
     "fmt"

-    cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
     "git.mlow.ca/mlow/lmcli/pkg/lmcli"
     "git.mlow.ca/mlow/lmcli/pkg/tui"
     "github.com/spf13/cobra"
@@ -15,16 +14,11 @@ func ChatCmd(ctx *lmcli.Context) *cobra.Command {
         Short: "Open the chat interface",
         Long:  `Open the chat interface, optionally on a given conversation.`,
         RunE: func(cmd *cobra.Command, args []string) error {
+            // TODO: implement jump-to-conversation logic
             shortname := ""
             if len(args) == 1 {
                 shortname = args[0]
             }
-            if shortname != ""{
-                _, err := cmdutil.LookupConversationE(ctx, shortname)
-                if err != nil {
-                    return err
-                }
-            }
             err := tui.Launch(ctx, shortname)
             if err != nil {
                 return fmt.Errorf("Error fetching LLM response: %v", err)
@@ -5,6 +5,7 @@ import (

     cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
     "git.mlow.ca/mlow/lmcli/pkg/lmcli"
+    "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
     "github.com/spf13/cobra"
 )

@@ -27,12 +28,36 @@ func CloneCmd(ctx *lmcli.Context) *cobra.Command {
                 return err
             }

-            clone, messageCnt, err := ctx.Store.CloneConversation(*toClone)
+            messagesToCopy, err := ctx.Store.Messages(toClone)
             if err != nil {
-                return fmt.Errorf("Failed to clone conversation: %v", err)
+                return fmt.Errorf("Could not retrieve messages for conversation: %s", toClone.ShortName.String)
             }

-            fmt.Printf("Cloned %d messages to: %s - %s\n", messageCnt, clone.ShortName.String, clone.Title)
+            clone := &model.Conversation{
+                Title: toClone.Title + " - Clone",
+            }
+            if err := ctx.Store.SaveConversation(clone); err != nil {
+                return fmt.Errorf("Cloud not create clone: %s", err)
+            }
+
+            var errors []error
+            messageCnt := 0
+            for _, message := range messagesToCopy {
+                newMessage := message
+                newMessage.ConversationID = clone.ID
+                newMessage.ID = 0
+                if err := ctx.Store.SaveMessage(&newMessage); err != nil {
+                    errors = append(errors, err)
+                } else {
+                    messageCnt++
+                }
+            }
+
+            if len(errors) > 0 {
+                return fmt.Errorf("Messages failed to be cloned: %v", errors)
+            }
+
+            fmt.Printf("Cloned %d messages to: %s\n", messageCnt, clone.Title)
             return nil
         },
         ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
@@ -26,7 +26,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
             shortName := args[0]
             conversation := cmdutil.LookupConversation(ctx, shortName)

-            messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
+            messages, err := ctx.Store.Messages(conversation)
             if err != nil {
                 return fmt.Errorf("could not retrieve conversation messages: %v", err)
             }
@@ -24,7 +24,7 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
             shortName := args[0]
             conversation := cmdutil.LookupConversation(ctx, shortName)

-            messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
+            messages, err := ctx.Store.Messages(conversation)
             if err != nil {
                 return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
             }
@@ -39,7 +39,21 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
             }

             desiredIdx := len(messages) - 1 - offset
-            toEdit := messages[desiredIdx]
+
+            // walk backwards through the conversation deleting messages until and
+            // including the last user message
+            toRemove := []model.Message{}
+            var toEdit *model.Message
+            for i := len(messages) - 1; i >= 0; i-- {
+                if i == desiredIdx {
+                    toEdit = &messages[i]
+                }
+                toRemove = append(toRemove, messages[i])
+                messages = messages[:i]
+                if toEdit != nil {
+                    break
+                }
+            }

             newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", toEdit.Content)
             switch newContents {
@@ -49,17 +63,26 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
                 return fmt.Errorf("No message was provided.")
             }

-            toEdit.Content = newContents
-
             role, _ := cmd.Flags().GetString("role")
-            if role != "" {
-                if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) {
-                    return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.")
-                }
-                toEdit.Role = model.MessageRole(role)
+            if role == "" {
+                role = string(toEdit.Role)
+            } else if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) {
+                return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.")
             }

-            return ctx.Store.UpdateMessage(&toEdit)
+            for _, message := range toRemove {
+                err = ctx.Store.DeleteMessage(&message)
+                if err != nil {
+                    lmcli.Warn("Could not delete message: %v\n", err)
+                }
+            }
+
+            cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{
+                ConversationID: conversation.ID,
+                Role:           model.MessageRole(role),
+                Content:        newContents,
+            })
+            return nil
         },
         ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
             compMode := cobra.ShellCompDirectiveNoFileComp
@@ -2,6 +2,7 @@ package cmd

 import (
     "fmt"
+    "slices"
     "time"

     "git.mlow.ca/mlow/lmcli/pkg/lmcli"
@@ -20,7 +21,7 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
         Short: "List conversations",
         Long:  `List conversations in order of recent activity`,
         RunE: func(cmd *cobra.Command, args []string) error {
-            messages, err := ctx.Store.LatestConversationMessages()
+            conversations, err := ctx.Store.Conversations()
             if err != nil {
                 return fmt.Errorf("Could not fetch conversations: %v", err)
             }
@@ -57,8 +58,13 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {

             all, _ := cmd.Flags().GetBool("all")

-            for _, message := range messages {
-                messageAge := now.Sub(message.CreatedAt)
+            for _, conversation := range conversations {
+                lastMessage, err := ctx.Store.LastMessage(&conversation)
+                if lastMessage == nil || err != nil {
+                    continue
+                }
+
+                messageAge := now.Sub(lastMessage.CreatedAt)

                 var category string
                 for _, c := range categories {
@@ -70,9 +76,9 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {

                 formatted := fmt.Sprintf(
                     "%s - %s - %s",
-                    message.Conversation.ShortName.String,
+                    conversation.ShortName.String,
                     util.HumanTimeElapsedSince(messageAge),
-                    message.Conversation.Title,
+                    conversation.Title,
                 )

                 categorized[category] = append(
@@ -90,10 +96,14 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
                     continue
                 }

+                slices.SortFunc(conversationLines, func(a, b ConversationLine) int {
+                    return int(a.timeSinceReply - b.timeSinceReply)
+                })
+
                 fmt.Printf("%s:\n", category.name)
                 for _, conv := range conversationLines {
                     if conversationsPrinted >= count && !all {
-                        fmt.Printf("%d remaining conversation(s), use --all to view.\n", len(messages)-conversationsPrinted)
+                        fmt.Printf("%d remaining message(s), use --all to view.\n", len(conversations)-conversationsPrinted)
                         break outer
                     }

@@ -15,43 +15,42 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
         Short: "Start a new conversation",
         Long:  `Start a new conversation with the Large Language Model.`,
         RunE: func(cmd *cobra.Command, args []string) error {
-            input := inputFromArgsOrEditor(args, "# Start a new conversation below\n", "")
-            if input == "" {
+            messageContents := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
+            if messageContents == "" {
                 return fmt.Errorf("No message was provided.")
             }

-            var messages []model.Message
-
-            // TODO: probably just make this part of the conversation
-            system := ctx.GetSystemPrompt()
-            if system != "" {
-                messages = append(messages, model.Message{
-                    Role:    model.MessageRoleSystem,
-                    Content: system,
-                })
+            conversation := &model.Conversation{}
+            err := ctx.Store.SaveConversation(conversation)
+            if err != nil {
+                return fmt.Errorf("Could not save new conversation: %v", err)
             }

-            messages = append(messages, model.Message{
-                Role:    model.MessageRoleUser,
-                Content: input,
-            })
-
-            conversation, messages, err := ctx.Store.StartConversation(messages...)
-            if err != nil {
-                return fmt.Errorf("Could not start a new conversation: %v", err)
+            messages := []model.Message{
+                {
+                    ConversationID: conversation.ID,
+                    Role:           model.MessageRoleSystem,
+                    Content:        ctx.GetSystemPrompt(),
+                },
+                {
+                    ConversationID: conversation.ID,
+                    Role:           model.MessageRoleUser,
+                    Content:        messageContents,
+                },
             }

-            cmdutil.HandleReply(ctx, &messages[len(messages)-1], true)
+            cmdutil.HandleConversationReply(ctx, conversation, true, messages...)

-            title, err := cmdutil.GenerateTitle(ctx, messages)
+            title, err := cmdutil.GenerateTitle(ctx, conversation)
             if err != nil {
-                lmcli.Warn("Could not generate title for conversation %s: %v\n", conversation.ShortName.String, err)
+                lmcli.Warn("Could not generate title for conversation: %v\n", err)
             }

             conversation.Title = title
-            err = ctx.Store.UpdateConversation(conversation)
+            err = ctx.Store.SaveConversation(conversation)
             if err != nil {
-                lmcli.Warn("Could not save conversation title: %v\n", err)
+                lmcli.Warn("Could not save conversation after generating title: %v\n", err)
             }
             return nil
         },
@@ -15,27 +15,22 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
         Short: "Do a one-shot prompt",
         Long:  `Prompt the Large Language Model and get a response.`,
         RunE: func(cmd *cobra.Command, args []string) error {
-            input := inputFromArgsOrEditor(args, "# Write your prompt below\n", "")
-            if input == "" {
+            message := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
+            if message == "" {
                 return fmt.Errorf("No message was provided.")
             }

-            var messages []model.Message
-
-            // TODO: stop supplying system prompt as a message
-            system := ctx.GetSystemPrompt()
-            if system != "" {
-                messages = append(messages, model.Message{
+            messages := []model.Message{
+                {
                     Role:    model.MessageRoleSystem,
-                    Content: system,
-                })
+                    Content: ctx.GetSystemPrompt(),
+                },
+                {
+                    Role:    model.MessageRoleUser,
+                    Content: message,
+                },
             }

-            messages = append(messages, model.Message{
-                Role:    model.MessageRoleUser,
-                Content: input,
-            })
-
             _, err := cmdutil.Prompt(ctx, messages, nil)
             if err != nil {
                 return fmt.Errorf("Error fetching LLM response: %v", err)
@@ -24,17 +24,12 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
         RunE: func(cmd *cobra.Command, args []string) error {
             shortName := args[0]
             conversation := cmdutil.LookupConversation(ctx, shortName)

             var err error
-            var title string
-
             generate, _ := cmd.Flags().GetBool("generate")
+            var title string
             if generate {
-                messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
-                if err != nil {
-                    return fmt.Errorf("Could not retrieve conversation messages: %v", err)
-                }
-                title, err = cmdutil.GenerateTitle(ctx, messages)
+                title, err = cmdutil.GenerateTitle(ctx, conversation)
                 if err != nil {
                     return fmt.Errorf("Could not generate conversation title: %v", err)
                 }
@@ -46,9 +41,9 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
             }

             conversation.Title = title
-            err = ctx.Store.UpdateConversation(conversation)
+            err = ctx.Store.SaveConversation(conversation)
             if err != nil {
-                lmcli.Warn("Could not update conversation title: %v\n", err)
+                lmcli.Warn("Could not save conversation with new title: %v\n", err)
             }
             return nil
         },
@@ -31,9 +31,9 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
             }

             cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{
                 ConversationID: conversation.ID,
                 Role:           model.MessageRoleUser,
                 Content:        reply,
             })
             return nil
         },
@@ -13,7 +13,7 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
     cmd := &cobra.Command{
         Use:   "retry <conversation>",
         Short: "Retry the last user reply in a conversation",
-        Long:  `Prompt the conversation from the last user response.`,
+        Long:  `Re-prompt the conversation up to the last user response. Can be used to regenerate the last assistant reply, or simply generate one if an error occurred.`,
         Args: func(cmd *cobra.Command, args []string) error {
             argCount := 1
             if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
|||||||
shortName := args[0]
|
shortName := args[0]
|
||||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||||
|
|
||||||
// Load the complete thread from the root message
|
messages, err := ctx.Store.Messages(conversation)
|
||||||
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find the last user message in the conversation
|
// walk backwards through the conversation and delete messages, break
|
||||||
var lastUserMessage *model.Message
|
// when we find the latest user response
|
||||||
var i int
|
for i := len(messages) - 1; i >= 0; i-- {
|
||||||
for i = len(messages) - 1; i >= 0; i-- {
|
|
||||||
if messages[i].Role == model.MessageRoleUser {
|
if messages[i].Role == model.MessageRoleUser {
|
||||||
lastUserMessage = &messages[i]
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = ctx.Store.DeleteMessage(&messages[i])
|
||||||
|
if err != nil {
|
||||||
|
lmcli.Warn("Could not delete previous reply: %v\n", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if lastUserMessage == nil {
|
cmdutil.HandleConversationReply(ctx, conversation, true)
|
||||||
return fmt.Errorf("No user message found in the conversation: %s", conversation.Title)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start a new branch at the last user message
|
|
||||||
cmdutil.HandleReply(ctx, lastUserMessage, true)
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||||
|
@@ -57,7 +57,7 @@ func LookupConversation(ctx *lmcli.Context, shortName string) *model.Conversatio
         lmcli.Fatal("Could not lookup conversation: %v\n", err)
     }
     if c.ID == 0 {
-        lmcli.Fatal("Conversation not found: %s\n", shortName)
+        lmcli.Fatal("Conversation not found with short name: %s\n", shortName)
     }
     return c
 }
|
|||||||
return nil, fmt.Errorf("Could not lookup conversation: %v", err)
|
return nil, fmt.Errorf("Could not lookup conversation: %v", err)
|
||||||
}
|
}
|
||||||
if c.ID == 0 {
|
if c.ID == 0 {
|
||||||
return nil, fmt.Errorf("Conversation not found: %s", shortName)
|
return nil, fmt.Errorf("Conversation not found with short name: %s", shortName)
|
||||||
}
|
}
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
|
|
||||||
messages, err := ctx.Store.PathToLeaf(c.SelectedRoot)
|
|
||||||
if err != nil {
|
|
||||||
lmcli.Fatal("Could not load messages: %v\n", err)
|
|
||||||
}
|
|
||||||
HandleReply(ctx, &messages[len(messages)-1], persist, toSend...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleConversationReply handles sending messages to an existing
|
// handleConversationReply handles sending messages to an existing
|
||||||
// conversation, optionally persisting both the sent replies and responses.
|
// conversation, optionally persisting both the sent replies and responses.
|
||||||
func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages ...model.Message) {
|
func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
|
||||||
if to == nil {
|
existing, err := ctx.Store.Messages(c)
|
||||||
lmcli.Fatal("Can't prompt from an empty message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
existing, err := ctx.Store.PathToRoot(to)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Fatal("Could not load messages: %v\n", err)
|
lmcli.Fatal("Could not retrieve messages for conversation: %s\n", c.Title)
|
||||||
}
|
}
|
||||||
|
|
||||||
RenderConversation(ctx, append(existing, messages...), true)
|
if persist {
|
||||||
|
for _, message := range toSend {
|
||||||
var savedReplies []model.Message
|
err = ctx.Store.SaveMessage(&message)
|
||||||
if persist && len(messages) > 0 {
|
if err != nil {
|
||||||
savedReplies, err = ctx.Store.Reply(to, messages...)
|
lmcli.Warn("Could not save %s message: %v\n", message.Role, err)
|
||||||
if err != nil {
|
}
|
||||||
lmcli.Warn("Could not save messages: %v\n", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
allMessages := append(existing, toSend...)
|
||||||
|
|
||||||
|
RenderConversation(ctx, allMessages, true)
|
||||||
|
|
||||||
// render a message header with no contents
|
// render a message header with no contents
|
||||||
RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
|
RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
|
||||||
|
|
||||||
var lastMessage *model.Message
|
|
||||||
|
|
||||||
lastMessage = to
|
|
||||||
if len(savedReplies) > 1 {
|
|
||||||
lastMessage = &savedReplies[len(savedReplies)-1]
|
|
||||||
}
|
|
||||||
|
|
||||||
replyCallback := func(reply model.Message) {
|
replyCallback := func(reply model.Message) {
|
||||||
if !persist {
|
if !persist {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
savedReplies, err = ctx.Store.Reply(lastMessage, reply)
|
|
||||||
|
reply.ConversationID = c.ID
|
||||||
|
err = ctx.Store.SaveMessage(&reply)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Warn("Could not save reply: %v\n", err)
|
lmcli.Warn("Could not save reply: %v\n", err)
|
||||||
}
|
}
|
||||||
lastMessage = &savedReplies[0]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = Prompt(ctx, append(existing, messages...), replyCallback)
|
_, err = Prompt(ctx, allMessages, replyCallback)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Fatal("Error fetching LLM response: %v\n", err)
|
lmcli.Fatal("Error fetching LLM response: %v\n", err)
|
||||||
}
|
}
|
||||||
@ -149,7 +134,12 @@ func FormatForExternalPrompt(messages []model.Message, system bool) string {
|
|||||||
return sb.String()
|
return sb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenerateTitle(ctx *lmcli.Context, messages []model.Message) (string, error) {
|
func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) {
|
||||||
|
messages, err := ctx.Store.Messages(c)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
const prompt = `Above is an excerpt from a conversation between a user and AI assistant. Please reply with a short title (no more than 8 words) that reflects the topic of the conversation, read from the user's perspective.
|
const prompt = `Above is an excerpt from a conversation between a user and AI assistant. Please reply with a short title (no more than 8 words) that reflects the topic of the conversation, read from the user's perspective.
|
||||||
|
|
||||||
Example conversation:
|
Example conversation:
|
||||||
@ -196,7 +186,6 @@ Title: A brief introduction
|
|||||||
|
|
||||||
response = strings.TrimPrefix(response, "Title: ")
|
response = strings.TrimPrefix(response, "Title: ")
|
||||||
response = strings.Trim(response, "\"")
|
response = strings.Trim(response, "\"")
|
||||||
response = strings.TrimSpace(response)
|
|
||||||
|
|
||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
@@ -24,9 +24,9 @@ func ViewCmd(ctx *lmcli.Context) *cobra.Command {
             shortName := args[0]
             conversation := cmdutil.LookupConversation(ctx, shortName)

-            messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
+            messages, err := ctx.Store.Messages(conversation)
             if err != nil {
-                return fmt.Errorf("Could not retrieve messages for conversation %s: %v", conversation.ShortName.String, err)
+                return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
             }

             cmdutil.RenderConversation(ctx, messages, false)
@@ -22,7 +22,6 @@ type Config struct {
         EnabledTools []string `yaml:"enabledTools"`
     } `yaml:"tools"`
     Providers []*struct {
-        Name    *string `yaml:"name"`
         Kind    *string `yaml:"kind"`
         BaseURL *string `yaml:"baseUrl"`
         APIKey  *string `yaml:"apiKey"`
@@ -4,12 +4,10 @@ import (
     "fmt"
     "os"
     "path/filepath"
-    "strings"

     "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
     "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
     "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
-    "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/google"
     "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
     "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
     "git.mlow.ca/mlow/lmcli/pkg/util"
|
|||||||
}
|
}
|
||||||
|
|
||||||
databaseFile := filepath.Join(dataDir(), "conversations.db")
|
databaseFile := filepath.Join(dataDir(), "conversations.db")
|
||||||
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{
|
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{})
|
||||||
//Logger: logger.Default.LogMode(logger.Info),
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
|
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
|
||||||
}
|
}
|
||||||
@ -61,37 +57,16 @@ func NewContext() (*Context, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) GetModels() (models []string) {
|
func (c *Context) GetModels() (models []string) {
|
||||||
modelCounts := make(map[string]int)
|
|
||||||
for _, p := range c.Config.Providers {
|
for _, p := range c.Config.Providers {
|
||||||
for _, m := range *p.Models {
|
for _, m := range *p.Models {
|
||||||
modelCounts[m]++
|
|
||||||
models = append(models, *p.Name+"/"+m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for m, c := range modelCounts {
|
|
||||||
if c == 1 {
|
|
||||||
models = append(models, m)
|
models = append(models, m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) {
|
func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) {
|
||||||
parts := strings.Split(model, "/")
|
|
||||||
|
|
||||||
var provider string
|
|
||||||
if len(parts) > 1 {
|
|
||||||
provider = parts[0]
|
|
||||||
model = parts[1]
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, p := range c.Config.Providers {
|
for _, p := range c.Config.Providers {
|
||||||
if provider != "" && *p.Name != provider {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, m := range *p.Models {
|
for _, m := range *p.Models {
|
||||||
if m == model {
|
if m == model {
|
||||||
switch *p.Kind {
|
switch *p.Kind {
|
||||||
@ -100,28 +75,21 @@ func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionCl
|
|||||||
if p.BaseURL != nil {
|
if p.BaseURL != nil {
|
||||||
url = *p.BaseURL
|
url = *p.BaseURL
|
||||||
}
|
}
|
||||||
return &anthropic.AnthropicClient{
|
anthropic := &anthropic.AnthropicClient{
|
||||||
BaseURL: url,
|
BaseURL: url,
|
||||||
APIKey: *p.APIKey,
|
APIKey: *p.APIKey,
|
||||||
}, nil
|
|
||||||
case "google":
|
|
||||||
url := "https://generativelanguage.googleapis.com"
|
|
||||||
if p.BaseURL != nil {
|
|
||||||
url = *p.BaseURL
|
|
||||||
}
|
}
|
||||||
return &google.Client{
|
return anthropic, nil
|
||||||
BaseURL: url,
|
|
||||||
APIKey: *p.APIKey,
|
|
||||||
}, nil
|
|
||||||
case "openai":
|
case "openai":
|
||||||
url := "https://api.openai.com/v1"
|
url := "https://api.openai.com/v1"
|
||||||
if p.BaseURL != nil {
|
if p.BaseURL != nil {
|
||||||
url = *p.BaseURL
|
url = *p.BaseURL
|
||||||
}
|
}
|
||||||
return &openai.OpenAIClient{
|
openai := &openai.OpenAIClient{
|
||||||
BaseURL: url,
|
BaseURL: url,
|
||||||
APIKey: *p.APIKey,
|
APIKey: *p.APIKey,
|
||||||
}, nil
|
}
|
||||||
|
return openai, nil
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unknown provider kind: %s", *p.Kind)
|
return nil, fmt.Errorf("unknown provider kind: %s", *p.Kind)
|
||||||
}
|
}
|
||||||
|
@@ -16,28 +16,19 @@ const (
 )

 type Message struct {
     ID             uint `gorm:"primaryKey"`
-    ConversationID uint `gorm:"index"`
-    Conversation   Conversation `gorm:"foreignKey:ConversationID"`
+    ConversationID uint `gorm:"foreignKey:ConversationID"`
     Content        string
     Role           MessageRole
-    CreatedAt      time.Time `gorm:"index"`
-    ToolCalls      ToolCalls   // a json array of tool calls (from the model)
+    CreatedAt      time.Time
+    ToolCalls      ToolCalls   // a json array of tool calls (from the modl)
     ToolResults    ToolResults // a json array of tool results
-    ParentID       *uint `gorm:"index"`
-    Parent         *Message  `gorm:"foreignKey:ParentID"`
-    Replies        []Message `gorm:"foreignKey:ParentID"`
-
-    SelectedReplyID *uint
-    SelectedReply   *Message `gorm:"foreignKey:SelectedReplyID"`
 }

 type Conversation struct {
     ID        uint `gorm:"primaryKey"`
     ShortName sql.NullString
     Title     string
-    SelectedRootID *uint
-    SelectedRoot   *Message `gorm:"foreignKey:SelectedRootID"`
 }

 type RequestParameters struct {
@@ -1,413 +0,0 @@
-package google
-
-import (
-    "bufio"
-    "bytes"
-    "context"
-    "encoding/json"
-    "fmt"
-    "io"
-    "net/http"
-    "strings"
-
-    "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
-    "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
-    "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
-)
-
-func convertTools(tools []model.Tool) []Tool {
-    geminiTools := make([]Tool, len(tools))
-    for i, tool := range tools {
-        params := make(map[string]ToolParameter)
-        var required []string
-
-        for _, param := range tool.Parameters {
-            // TODO: proper enum handing
-            params[param.Name] = ToolParameter{
-                Type:        param.Type,
-                Description: param.Description,
-                Values:      param.Enum,
-            }
-            if param.Required {
-                required = append(required, param.Name)
-            }
-        }
-
-        geminiTools[i] = Tool{
-            FunctionDeclarations: []FunctionDeclaration{
-                {
-                    Name:        tool.Name,
-                    Description: tool.Description,
-                    Parameters: ToolParameters{
-                        Type:       "OBJECT",
-                        Properties: params,
-                        Required:   required,
-                    },
-                },
-            },
-        }
-    }
-    return geminiTools
-}
-
-func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart {
-    converted := make([]ContentPart, len(toolCalls))
-    for i, call := range toolCalls {
-        args := make(map[string]string)
-        for k, v := range call.Parameters {
-            args[k] = fmt.Sprintf("%v", v)
-        }
-        converted[i].FunctionCall = &FunctionCall{
-            Name: call.Name,
-            Args: args,
-        }
-    }
-    return converted
-}
-
-func convertToolCallToAPI(functionCalls []FunctionCall) []model.ToolCall {
-    converted := make([]model.ToolCall, len(functionCalls))
-    for i, call := range functionCalls {
-        params := make(map[string]interface{})
-        for k, v := range call.Args {
-            params[k] = v
-        }
-        converted[i].Name = call.Name
-        converted[i].Parameters = params
-    }
-    return converted
-}
-
-func convertToolResultsToGemini(toolResults []model.ToolResult) ([]FunctionResponse, error) {
-    results := make([]FunctionResponse, len(toolResults))
-    for i, result := range toolResults {
-        var obj interface{}
-        err := json.Unmarshal([]byte(result.Result), &obj)
-        if err != nil {
-            return nil, fmt.Errorf("Could not unmarshal %s: %v", result.Result, err)
-        }
-        results[i] = FunctionResponse{
-            Name:     result.ToolName,
-            Response: obj,
-        }
-    }
-    return results, nil
-}
-
-func createGenerateContentRequest(
-    params model.RequestParameters,
-    messages []model.Message,
-) (*GenerateContentRequest, error) {
-    requestContents := make([]Content, 0, len(messages))
-
-    startIdx := 0
-    var system string
-    if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
-        system = messages[0].Content
-        startIdx = 1
-    }
-
-    for _, m := range messages[startIdx:] {
-        switch m.Role {
-        case "tool_call":
-            content := Content{
-                Role:  "model",
-                Parts: convertToolCallToGemini(m.ToolCalls),
-            }
-            requestContents = append(requestContents, content)
-        case "tool_result":
-            results, err := convertToolResultsToGemini(m.ToolResults)
-            if err != nil {
-                return nil, err
-            }
-            // expand tool_result messages' results into multiple gemini messages
-            for _, result := range results {
-                content := Content{
-                    Role: "function",
-                    Parts: []ContentPart{
-                        {
-                            FunctionResp: &result,
-                        },
-                    },
-                }
-                requestContents = append(requestContents, content)
-            }
-        default:
-            var role string
-            switch m.Role {
-            case model.MessageRoleAssistant:
-                role = "model"
-            case model.MessageRoleUser:
-                role = "user"
-            }
-
-            if role == "" {
-                panic("Unhandled role: " + m.Role)
-            }
-
-            content := Content{
-                Role: role,
-                Parts: []ContentPart{
-                    {
-                        Text: m.Content,
-                    },
-                },
-            }
-            requestContents = append(requestContents, content)
-        }
-    }
-
-    request := &GenerateContentRequest{
-        Contents:           requestContents,
-        SystemInstructions: system,
-        GenerationConfig: &GenerationConfig{
-            MaxOutputTokens: &params.MaxTokens,
-            Temperature:     &params.Temperature,
-            TopP:            &params.TopP,
-        },
-    }
-
-    if len(params.ToolBag) > 0 {
-        request.Tools = convertTools(params.ToolBag)
-    }
-
-    return request, nil
-}
-
-func handleToolCalls(
-    params model.RequestParameters,
-    content string,
-    toolCalls []model.ToolCall,
-    callback provider.ReplyCallback,
-    messages []model.Message,
-) ([]model.Message, error) {
-    lastMessage := messages[len(messages)-1]
-    continuation := false
-    if lastMessage.Role.IsAssistant() {
-        continuation = true
-    }
-
-    toolCall := model.Message{
-        Role:      model.MessageRoleToolCall,
-        Content:   content,
-        ToolCalls: toolCalls,
-    }
-
-    toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
-    if err != nil {
-        return nil, err
-    }
-
-    toolResult := model.Message{
-        Role:        model.MessageRoleToolResult,
-        ToolResults: toolResults,
-    }
-
-    if callback != nil {
-        callback(toolCall)
-        callback(toolResult)
-    }
-
-    if continuation {
-        messages[len(messages)-1] = toolCall
-    } else {
-        messages = append(messages, toolCall)
-    }
-    messages = append(messages, toolResult)
-
-    return messages, nil
-}
-
-func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
-    req.Header.Set("Content-Type", "application/json")
-
-    client := &http.Client{}
-    resp, err := client.Do(req.WithContext(ctx))
-
-    if resp.StatusCode != 200 {
-        bytes, _ := io.ReadAll(resp.Body)
-        return resp, fmt.Errorf("%v", string(bytes))
-    }
-
-    return resp, err
-}
-
-func (c *Client) CreateChatCompletion(
-    ctx context.Context,
-    params model.RequestParameters,
-    messages []model.Message,
-    callback provider.ReplyCallback,
-) (string, error) {
-    if len(messages) == 0 {
-        return "", fmt.Errorf("Can't create completion from no messages")
-    }
-
-    req, err := createGenerateContentRequest(params, messages)
-    if err != nil {
-        return "", err
-    }
-    jsonData, err := json.Marshal(req)
-    if err != nil {
-        return "", err
-    }
-
-    url := fmt.Sprintf(
-        "%s/v1beta/models/%s:generateContent?key=%s",
-        c.BaseURL, params.Model, c.APIKey,
-    )
-    httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
-    if err != nil {
-        return "", err
-    }
-
-    resp, err := c.sendRequest(ctx, httpReq)
-    if err != nil {
-        return "", err
-    }
-    defer resp.Body.Close()
-
-    var completionResp GenerateContentResponse
-    err = json.NewDecoder(resp.Body).Decode(&completionResp)
-    if err != nil {
-        return "", err
-    }
-
-    choice := completionResp.Candidates[0]
-
-    var content string
-    lastMessage := messages[len(messages)-1]
-    if lastMessage.Role.IsAssistant() {
-        content = lastMessage.Content
-    }
-
-    var toolCalls []FunctionCall
-    for _, part := range choice.Content.Parts {
-        if part.Text != "" {
-            content += part.Text
-        }
-
-        if part.FunctionCall != nil {
-            toolCalls = append(toolCalls, *part.FunctionCall)
-        }
-    }
-
-    if len(toolCalls) > 0 {
-        messages, err := handleToolCalls(
-            params, content, convertToolCallToAPI(toolCalls), callback, messages,
-        )
-        if err != nil {
-            return content, err
-        }
-
-        return c.CreateChatCompletion(ctx, params, messages, callback)
-    }
-
-    if callback != nil {
-        callback(model.Message{
-            Role:    model.MessageRoleAssistant,
-            Content: content,
-        })
-    }
-
-    return content, nil
-}
-
-func (c *Client) CreateChatCompletionStream(
-    ctx context.Context,
-    params model.RequestParameters,
-    messages []model.Message,
-    callback provider.ReplyCallback,
-    output chan<- string,
-) (string, error) {
-    if len(messages) == 0 {
-        return "", fmt.Errorf("Can't create completion from no messages")
-    }
-
-    req, err := createGenerateContentRequest(params, messages)
-    if err != nil {
-        return "", err
-    }
-    jsonData, err := json.Marshal(req)
-    if err != nil {
-        return "", err
-    }
-
-    url := fmt.Sprintf(
-        "%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
-        c.BaseURL, params.Model, c.APIKey,
-    )
-    httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
-    if err != nil {
-        return "", err
-    }
-
-    resp, err := c.sendRequest(ctx, httpReq)
-    if err != nil {
-        return "", err
-    }
-    defer resp.Body.Close()
-
-    content := strings.Builder{}
-
-    lastMessage := messages[len(messages)-1]
-    if lastMessage.Role.IsAssistant() {
-        content.WriteString(lastMessage.Content)
-    }
-
-    var toolCalls []FunctionCall
-
-    reader := bufio.NewReader(resp.Body)
-    for {
-        line, err := reader.ReadBytes('\n')
-        if err != nil {
-            if err == io.EOF {
-                break
-            }
-            return "", err
-        }
-
-        line = bytes.TrimSpace(line)
-        if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
-            continue
-        }
-
-        line = bytes.TrimPrefix(line, []byte("data: "))
-
-        var streamResp GenerateContentResponse
-        err = json.Unmarshal(line, &streamResp)
-        if err != nil {
-            return "", err
-        }
-
-        for _, candidate := range streamResp.Candidates {
-            for _, part := range candidate.Content.Parts {
-                if part.FunctionCall != nil {
-                    toolCalls = append(toolCalls, *part.FunctionCall)
-                } else if part.Text != "" {
-                    output <- part.Text
-                    content.WriteString(part.Text)
-                }
-            }
-        }
-    }
-
-    // If there are function calls, handle them and recurse
-    if len(toolCalls) > 0 {
-        messages, err := handleToolCalls(
-            params, content.String(), convertToolCallToAPI(toolCalls), callback, messages,
-        )
-        if err != nil {
-            return content.String(), err
-        }
-        return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
-    }
-
-    if callback != nil {
-        callback(model.Message{
-            Role:    model.MessageRoleAssistant,
-            Content: content.String(),
-        })
-    }
-
-    return content.String(), nil
-}
@@ -1,80 +0,0 @@
-package google
-
-type Client struct {
-    APIKey  string
-    BaseURL string
-}
-
-type ContentPart struct {
-    Text         string            `json:"text,omitempty"`
-    FunctionCall *FunctionCall     `json:"functionCall,omitempty"`
-    FunctionResp *FunctionResponse `json:"functionResponse,omitempty"`
-}
-
-type FunctionCall struct {
-    Name string            `json:"name"`
-    Args map[string]string `json:"args"`
-}
-
-type FunctionResponse struct {
-    Name     string      `json:"name"`
-    Response interface{} `json:"response"`
-}
-
-type Content struct {
-    Role  string        `json:"role"`
-    Parts []ContentPart `json:"parts"`
-}
-
-type GenerationConfig struct {
-    MaxOutputTokens *int     `json:"maxOutputTokens,omitempty"`
-    Temperature     *float32 `json:"temperature,omitempty"`
-    TopP            *float32 `json:"topP,omitempty"`
-    TopK            *int     `json:"topK,omitempty"`
-}
-
-type GenerateContentRequest struct {
-    Contents           []Content         `json:"contents"`
-    Tools              []Tool            `json:"tools,omitempty"`
-    SystemInstructions string            `json:"systemInstructions,omitempty"`
-    GenerationConfig   *GenerationConfig `json:"generationConfig,omitempty"`
-}
-
-type Candidate struct {
-    Content      Content `json:"content"`
-    FinishReason string  `json:"finishReason"`
-    Index        int     `json:"index"`
-}
-
-type UsageMetadata struct {
-    PromptTokenCount     int `json:"promptTokenCount"`
-    CandidatesTokenCount int `json:"candidatesTokenCount"`
-    TotalTokenCount      int `json:"totalTokenCount"`
-}
-
-type GenerateContentResponse struct {
-    Candidates    []Candidate   `json:"candidates"`
-    UsageMetadata UsageMetadata `json:"usageMetadata"`
-}
-
-type Tool struct {
-    FunctionDeclarations []FunctionDeclaration `json:"functionDeclarations"`
-}
-
-type FunctionDeclaration struct {
-    Name        string         `json:"name"`
-    Description string         `json:"description"`
-    Parameters  ToolParameters `json:"parameters"`
-}
-
-type ToolParameters struct {
-    Type       string                   `json:"type"`
-    Properties map[string]ToolParameter `json:"properties,omitempty"`
-    Required   []string                 `json:"required,omitempty"`
-}
-
-type ToolParameter struct {
-    Type        string   `json:"type"`
-    Description string   `json:"description"`
-    Values      []string `json:"values,omitempty"`
-}
@@ -13,26 +13,21 @@ import (
 )

 type ConversationStore interface {
+    Conversations() ([]model.Conversation, error)
+
     ConversationByShortName(shortName string) (*model.Conversation, error)
     ConversationShortNameCompletions(search string) []string
-    RootMessages(conversationID uint) ([]model.Message, error)
-    LatestConversationMessages() ([]model.Message, error)

-    StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error)
-    UpdateConversation(conversation *model.Conversation) error
+    SaveConversation(conversation *model.Conversation) error
     DeleteConversation(conversation *model.Conversation) error
-    CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error)

-    MessageByID(messageID uint) (*model.Message, error)
-    MessageReplies(messageID uint) ([]model.Message, error)
+    Messages(conversation *model.Conversation) ([]model.Message, error)
+    LastMessage(conversation *model.Conversation) (*model.Message, error)

+    SaveMessage(message *model.Message) error
+    DeleteMessage(message *model.Message) error
     UpdateMessage(message *model.Message) error
-    DeleteMessage(message *model.Message, prune bool) error
-    CloneBranch(toClone model.Message) (*model.Message, uint, error)
-    Reply(to *model.Message, messages ...model.Message) ([]model.Message, error)
-
-    PathToRoot(message *model.Message) ([]model.Message, error)
-    PathToLeaf(message *model.Message) ([]model.Message, error)
+    AddReply(conversation *model.Conversation, message model.Message) (*model.Message, error)
 }

 type SQLStore struct {
@@ -57,52 +52,47 @@ func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
     return &SQLStore{db, _sqids}, nil
 }

-func (s *SQLStore) saveNewConversation(c *model.Conversation) error {
-    // Save the new conversation
-    err := s.db.Save(&c).Error
+func (s *SQLStore) SaveConversation(conversation *model.Conversation) error {
+    err := s.db.Save(&conversation).Error
     if err != nil {
         return err
     }

-    // Generate and save its "short name"
-    shortName, _ := s.sqids.Encode([]uint64{uint64(c.ID)})
-    c.ShortName = sql.NullString{String: shortName, Valid: true}
-    return s.UpdateConversation(c)
-}
-
-func (s *SQLStore) UpdateConversation(c *model.Conversation) error {
-    if c == nil || c.ID == 0 {
-        return fmt.Errorf("Conversation is nil or invalid (missing ID)")
+    if !conversation.ShortName.Valid {
+        shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)})
+        conversation.ShortName = sql.NullString{String: shortName, Valid: true}
+        err = s.db.Save(&conversation).Error
     }
-    return s.db.Updates(&c).Error
+
+    return err
 }

-func (s *SQLStore) DeleteConversation(c *model.Conversation) error {
-    // Delete messages first
-    err := s.db.Where("conversation_id = ?", c.ID).Delete(&model.Message{}).Error
-    if err != nil {
-        return err
-    }
-    return s.db.Delete(&c).Error
+func (s *SQLStore) DeleteConversation(conversation *model.Conversation) error {
+    s.db.Where("conversation_id = ?", conversation.ID).Delete(&model.Message{})
+    return s.db.Delete(&conversation).Error
 }

-func (s *SQLStore) DeleteMessage(message *model.Message, prune bool) error {
-    panic("Not yet implemented")
-    //return s.db.Delete(&message).Error
+func (s *SQLStore) SaveMessage(message *model.Message) error {
+    return s.db.Create(message).Error
 }

-func (s *SQLStore) UpdateMessage(m *model.Message) error {
-    if m == nil || m.ID == 0 {
-        return fmt.Errorf("Message is nil or invalid (missing ID)")
-    }
-    return s.db.Updates(&m).Error
+func (s *SQLStore) DeleteMessage(message *model.Message) error {
+    return s.db.Delete(&message).Error
+}
+
+func (s *SQLStore) UpdateMessage(message *model.Message) error {
+    return s.db.Updates(&message).Error
+}
+
+func (s *SQLStore) Conversations() ([]model.Conversation, error) {
+    var conversations []model.Conversation
+    err := s.db.Find(&conversations).Error
+    return conversations, err
 }

 func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
-    var conversations []model.Conversation
-    // ignore error for completions
-    s.db.Find(&conversations)
-    completions := make([]string, 0, len(conversations))
+    var completions []string
+    conversations, _ := s.Conversations() // ignore error for completions
     for _, conversation := range conversations {
         if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
             completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
@ -116,249 +106,27 @@ func (s *SQLStore) ConversationByShortName(shortName string) (*model.Conversatio
 		return nil, errors.New("shortName is empty")
 	}
 	var conversation model.Conversation
-	err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error
+	err := s.db.Where("short_name = ?", shortName).Find(&conversation).Error
 	return &conversation, err
 }
 
-func (s *SQLStore) RootMessages(conversationID uint) ([]model.Message, error) {
-	var rootMessages []model.Message
-	err := s.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error
-	if err != nil {
-		return nil, err
-	}
-	return rootMessages, nil
+func (s *SQLStore) Messages(conversation *model.Conversation) ([]model.Message, error) {
+	var messages []model.Message
+	err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error
+	return messages, err
 }
 
-func (s *SQLStore) MessageByID(messageID uint) (*model.Message, error) {
+func (s *SQLStore) LastMessage(conversation *model.Conversation) (*model.Message, error) {
 	var message model.Message
-	err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error
+	err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error
 	return &message, err
 }
 
-func (s *SQLStore) MessageReplies(messageID uint) ([]model.Message, error) {
-	var replies []model.Message
-	err := s.db.Where("parent_id = ?", messageID).Find(&replies).Error
-	return replies, err
-}
-
-// StartConversation starts a new conversation with the provided messages
-func (s *SQLStore) StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error) {
-	if len(messages) == 0 {
-		return nil, nil, fmt.Errorf("Must provide at least 1 message")
-	}
-
-	// Create new conversation
-	conversation := &model.Conversation{}
-	err := s.saveNewConversation(conversation)
-	if err != nil {
-		return nil, nil, err
-	}
-
-	// Create first message
-	messages[0].ConversationID = conversation.ID
-	err = s.db.Create(&messages[0]).Error
-	if err != nil {
-		return nil, nil, err
-	}
-
-	// Update conversation's selected root message
-	conversation.SelectedRoot = &messages[0]
-	err = s.UpdateConversation(conversation)
-	if err != nil {
-		return nil, nil, err
-	}
-
-	// Add additional replies to conversation
-	if len(messages) > 1 {
-		newMessages, err := s.Reply(&messages[0], messages[1:]...)
-		if err != nil {
-			return nil, nil, err
-		}
-		messages = append([]model.Message{messages[0]}, newMessages...)
-	}
-	return conversation, messages, nil
-}
-
-// CloneConversation clones the given conversation and all of its root meesages
-func (s *SQLStore) CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error) {
-	rootMessages, err := s.RootMessages(toClone.ID)
-	if err != nil {
-		return nil, 0, err
-	}
-
-	clone := &model.Conversation{
-		Title: toClone.Title + " - Clone",
-	}
-	if err := s.saveNewConversation(clone); err != nil {
-		return nil, 0, fmt.Errorf("Could not create clone: %s", err)
-	}
-
-	var errors []error
-	var messageCnt uint = 0
-	for _, root := range rootMessages {
-		messageCnt++
-		newRoot := root
-		newRoot.ConversationID = clone.ID
-
-		cloned, count, err := s.CloneBranch(newRoot)
-		if err != nil {
-			errors = append(errors, err)
-			continue
-		}
-		messageCnt += count
-
-		if root.ID == *toClone.SelectedRootID {
-			clone.SelectedRootID = &cloned.ID
-			if err := s.UpdateConversation(clone); err != nil {
-				errors = append(errors, fmt.Errorf("Could not set selected root on clone: %v", err))
-			}
-		}
-	}
-
-	if len(errors) > 0 {
-		return nil, 0, fmt.Errorf("Messages failed to be cloned: %v", errors)
-	}
-
-	return clone, messageCnt, nil
-}
-
-// Reply replies to the given parentMessage with a series of messages
-func (s *SQLStore) Reply(parentMessage *model.Message, messages ...model.Message) ([]model.Message, error) {
-	var savedMessages []model.Message
-	currentParent := parentMessage
-
-	err := s.db.Transaction(func(tx *gorm.DB) error {
-		for i := range messages {
-			message := messages[i]
-			message.ConversationID = currentParent.ConversationID
-			message.ParentID = &currentParent.ID
-			message.ID = 0
-			message.CreatedAt = time.Time{}
-
-			if err := tx.Create(&message).Error; err != nil {
-				return err
-			}
-
-			// update parent selected reply
-			currentParent.SelectedReply = &message
-			if err := tx.Model(currentParent).Update("selected_reply_id", message.ID).Error; err != nil {
-				return err
-			}
-
-			savedMessages = append(savedMessages, message)
-			currentParent = &message
-		}
-		return nil
-	})
-
-	return savedMessages, err
-}
-
-// CloneBranch returns a deep clone of the given message and its replies, returning
-// a new message object. The new message will be attached to the same parent as
-// the message to clone.
-func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, uint, error) {
-	newMessage := messageToClone
-	newMessage.ID = 0
-	newMessage.Replies = nil
-	newMessage.SelectedReplyID = nil
-	newMessage.SelectedReply = nil
-
-	if err := s.db.Create(&newMessage).Error; err != nil {
-		return nil, 0, fmt.Errorf("Could not clone message: %s", err)
-	}
-
-	originalReplies, err := s.MessageReplies(messageToClone.ID)
-	if err != nil {
-		return nil, 0, fmt.Errorf("Could not fetch message %d replies: %v", messageToClone.ID, err)
-	}
-
-	var replyCount uint = 0
-	for _, reply := range originalReplies {
-		replyCount++
-
-		newReply := reply
-		newReply.ConversationID = messageToClone.ConversationID
-		newReply.ParentID = &newMessage.ID
-		newReply.Parent = &newMessage
-
-		res, c, err := s.CloneBranch(newReply)
-		if err != nil {
-			return nil, 0, err
-		}
-		newMessage.Replies = append(newMessage.Replies, *res)
-		replyCount += c
-
-		if reply.ID == *messageToClone.SelectedReplyID {
-			newMessage.SelectedReplyID = &res.ID
-			if err := s.UpdateMessage(&newMessage); err != nil {
-				return nil, 0, fmt.Errorf("Could not update parent select reply ID: %v", err)
-			}
-		}
-	}
-	return &newMessage, replyCount, nil
-}
-
-// PathToRoot traverses message Parent until reaching the tree root
-func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {
-	if message == nil {
-		return nil, fmt.Errorf("Message is nil")
-	}
-	var path []model.Message
-	current := message
-	for {
-		path = append([]model.Message{*current}, path...)
-		if current.Parent == nil {
-			break
-		}
-
-		var err error
-		current, err = s.MessageByID(*current.ParentID)
-		if err != nil {
-			return nil, fmt.Errorf("finding parent message: %w", err)
-		}
-	}
-	return path, nil
-}
-
-// PathToLeaf traverses message SelectedReply until reaching a tree leaf
-func (s *SQLStore) PathToLeaf(message *model.Message) ([]model.Message, error) {
-	if message == nil {
-		return nil, fmt.Errorf("Message is nil")
-	}
-	var path []model.Message
-	current := message
-	for {
-		path = append(path, *current)
-		if current.SelectedReplyID == nil {
-			break
-		}
-
-		var err error
-		current, err = s.MessageByID(*current.SelectedReplyID)
-		if err != nil {
-			return nil, fmt.Errorf("finding selected reply: %w", err)
-		}
-	}
-	return path, nil
-}
-
-func (s *SQLStore) LatestConversationMessages() ([]model.Message, error) {
-	var latestMessages []model.Message
-
-	subQuery := s.db.Model(&model.Message{}).
-		Select("MAX(created_at) as max_created_at, conversation_id").
-		Group("conversation_id")
-
-	err := s.db.Model(&model.Message{}).
-		Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery).
-		Order("created_at DESC").
-		Preload("Conversation").
-		Find(&latestMessages).Error
-
-	if err != nil {
-		return nil, err
-	}
-
-	return latestMessages, nil
+// AddReply adds the given messages as a reply to the given conversation, can be
+// used to easily copy a message associated with one conversation, to another
+func (s *SQLStore) AddReply(c *model.Conversation, m model.Message) (*model.Message, error) {
+	m.ConversationID = c.ID
+	m.ID = 0
+	m.CreatedAt = time.Time{}
+	return &m, s.SaveMessage(&m)
 }
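For orientation, here is a minimal sketch (not part of the diff) of how the tree-based store API on the db465f1 side composes: StartConversation creates the root, Reply attaches a new branch and marks it selected, and PathToLeaf walks the selected thread. The treeStore interface and the model.MessageRoleUser / model.MessageRoleAssistant constants are assumptions made for the example; the real code calls these methods directly on its SQLStore.

package example

import (
	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
)

// treeStore is an assumed interface used only for this sketch; it mirrors the
// method signatures shown on the db465f1 side of the diff above.
type treeStore interface {
	StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error)
	Reply(parent *model.Message, messages ...model.Message) ([]model.Message, error)
	PathToLeaf(message *model.Message) ([]model.Message, error)
}

func selectedThread(store treeStore) ([]model.Message, error) {
	// The first message becomes the conversation's selected root.
	conversation, msgs, err := store.StartConversation(
		model.Message{Role: model.MessageRoleUser, Content: "Hello"},
		model.Message{Role: model.MessageRoleAssistant, Content: "Hi, how can I help?"},
	)
	if err != nil {
		return nil, err
	}

	// Reply attaches a sibling branch under the root and updates the parent's
	// SelectedReply, so the new branch becomes the active one.
	if _, err := store.Reply(&msgs[0], model.Message{
		Role:    model.MessageRoleAssistant,
		Content: "Hey there.",
	}); err != nil {
		return nil, err
	}

	// Following SelectedReply pointers from the root yields the currently
	// selected thread through the message tree.
	return store.PathToLeaf(conversation.SelectedRoot)
}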
@ -11,9 +11,7 @@ import (
 	toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
 )
 
-const TREE_DESCRIPTION = `Retrieve a tree-like view of a directory's contents.
+const TREE_DESCRIPTION = `Retrieve a tree view of a directory's contents.
 
-Use these results for your own reference in completing your task, they do not need to be shown to the user.
-
 Example result:
 {
@ -37,45 +35,48 @@ var DirTreeTool = model.Tool{
 			Description: "If set, display the tree starting from this path relative to the current one.",
 		},
 		{
-			Name:        "depth",
+			Name:        "max_depth",
 			Type:        "integer",
-			Description: "Depth of directory recursion. Default 0. Use -1 for unlimited.",
+			Description: "Maximum depth of recursion. Default is unlimited.",
 		},
 	},
 	Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
 		var relativeDir string
-		if tmp, ok := args["relative_path"]; ok {
+		tmp, ok := args["relative_dir"]
+		if ok {
 			relativeDir, ok = tmp.(string)
 			if !ok {
-				return "", fmt.Errorf("expected string for relative_path, got %T", tmp)
+				return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp)
 			}
 		}
 
-		var depth int = 0 // Default value if not provided
-		if tmp, ok := args["depth"]; ok {
-			switch v := tmp.(type) {
-			case float64:
-				depth = int(v)
-			case string:
-				var err error
-				if depth, err = strconv.Atoi(v); err != nil {
-					return "", fmt.Errorf("invalid `depth` value, expected integer but got string that cannot convert: %v", tmp)
+		var maxDepth int = -1
+		tmp, ok = args["max_depth"]
+		if ok {
+			maxDepth, ok = tmp.(int)
+			if !ok {
+				if tmps, ok := tmp.(string); ok {
+					tmpi, err := strconv.Atoi(tmps)
+					maxDepth = tmpi
+					if err != nil {
+						return "", fmt.Errorf("Invalid max_depth in function arguments: %v", tmp)
+					}
+				} else {
+					return "", fmt.Errorf("Invalid max_depth in function arguments: %v", tmp)
 				}
-			default:
-				return "", fmt.Errorf("expected int or string for max_depth, got %T", tmp)
 			}
 		}
 
-		result := tree(relativeDir, depth)
+		result := tree(relativeDir, maxDepth)
 		ret, err := result.ToJson()
 		if err != nil {
-			return "", fmt.Errorf("could not serialize result: %v", err)
+			return "", fmt.Errorf("Could not serialize result: %v", err)
 		}
 		return ret, nil
 	},
 }
 
-func tree(path string, depth int) model.CallResult {
+func tree(path string, maxDepth int) model.CallResult {
 	if path == "" {
 		path = "."
 	}
@ -86,7 +87,7 @@ func tree(path string, depth int) model.CallResult {
 
 	var treeOutput strings.Builder
 	treeOutput.WriteString(path + "\n")
-	err := buildTree(&treeOutput, path, "", depth)
+	err := buildTree(&treeOutput, path, "", maxDepth)
 	if err != nil {
 		return model.CallResult{
 			Message: err.Error(),
@ -96,7 +97,7 @@ func tree(path string, depth int) model.CallResult {
 	return model.CallResult{Result: treeOutput.String()}
 }
 
-func buildTree(output *strings.Builder, path string, prefix string, depth int) error {
+func buildTree(output *strings.Builder, path string, prefix string, maxDepth int) error {
 	files, err := os.ReadDir(path)
 	if err != nil {
 		return err
@ -123,14 +124,14 @@ func buildTree(output *strings.Builder, path string, prefix string, depth int) e
 		output.WriteString(prefix + branch + file.Name())
 		if file.IsDir() {
 			output.WriteString("/\n")
-			if depth != 0 {
+			if maxDepth != 0 {
 				var nextPrefix string
 				if isLast {
 					nextPrefix = prefix + "    "
 				} else {
 					nextPrefix = prefix + "│   "
 				}
-				buildTree(output, filepath.Join(path, file.Name()), nextPrefix, depth-1)
+				buildTree(output, filepath.Join(path, file.Name()), nextPrefix, maxDepth-1)
 			}
 		} else {
 			output.WriteString(sizeStr + "\n")
@ -139,3 +140,4 @@ func buildTree(output *strings.Builder, path string, prefix string, depth int) e
 
 	return nil
 }
+
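Aside on the depth handling above: when tool arguments are decoded by encoding/json into a map[string]interface{}, JSON numbers arrive as float64 rather than int, which is what the float64/string switch on the db465f1 side accounts for. A small standalone illustration (the argument names here are only for the example):

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	var args map[string]interface{}
	_ = json.Unmarshal([]byte(`{"relative_path": "pkg", "depth": 2}`), &args)

	// With an interface{} target, encoding/json decodes JSON numbers as
	// float64, so a plain int type assertion on "depth" does not match.
	_, isInt := args["depth"].(int)
	f, isFloat := args["depth"].(float64)
	fmt.Println(isInt, isFloat, int(f)) // false true 2
}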
@ -9,9 +9,7 @@ import (
 	toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
 )
 
-const READ_FILE_DESCRIPTION = `Retrieve the contents of a text file relative to the current working directory.
+const READ_FILE_DESCRIPTION = `Read the contents of a text file relative to the current working directory.
 
-Use the file contents for your own reference in completing your task, they do not need to be shown to the user.
-
 Each line of the returned content is prefixed with its line number and a tab (\t).
 
@ -38,8 +38,8 @@ func ExecuteToolCalls(toolCalls []model.ToolCall, toolBag []model.Tool) ([]model
 
 		toolResult := model.ToolResult{
 			ToolCallID: toolCall.ID,
 			ToolName:   toolCall.Name,
 			Result:     result,
 		}
 
 		toolResults = append(toolResults, toolResult)

136 pkg/tui/chat.go
@ -66,10 +66,6 @@ type chatModel struct {
 	replyChunkChan chan string
 	persistence    bool // whether we will save new messages in the conversation
 
-	tokenCount uint
-	startTime  time.Time
-	elapsed    time.Duration
-
 	// ui state
 	focus focusState
 	wrap  bool // whether message content is wrapped to viewport width
@ -286,9 +282,6 @@ func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
 			}
 			m.updateContent()
 			cmds = append(cmds, m.waitForChunk()) // wait for the next chunk
-
-			m.tokenCount++
-			m.elapsed = time.Now().Sub(m.startTime)
 		case msgAssistantReply:
 			// the last reply that was being worked on is finished
 			reply := models.Message(msg)
@ -307,9 +300,14 @@ func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
 		}
 
 		if m.persistence {
-			err := m.persistConversation()
+			var err error
+			if m.conversation.ID == 0 {
+				err = m.ctx.Store.SaveConversation(m.conversation)
+			}
 			if err != nil {
 				cmds = append(cmds, wrapError(err))
+			} else {
+				cmds = append(cmds, m.persistConversation())
 			}
 		}
 
@ -336,7 +334,7 @@ func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
 		title := string(msg)
 		m.conversation.Title = title
 		if m.persistence {
-			err := m.ctx.Store.UpdateConversation(m.conversation)
+			err := m.ctx.Store.SaveConversation(m.conversation)
 			if err != nil {
 				cmds = append(cmds, wrapError(err))
 			}
@ -464,8 +462,8 @@ func (m *chatModel) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
 		m.input.Blur()
 		return true, nil
 	case "ctrl+s":
-		input := strings.TrimSpace(m.input.Value())
-		if input == "" {
+		userInput := strings.TrimSpace(m.input.Value())
+		if strings.TrimSpace(userInput) == "" {
 			return true, nil
 		}
 
@ -473,20 +471,36 @@ func (m *chatModel) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
 			return true, wrapError(fmt.Errorf("Can't reply to a user message"))
 		}
 
-		m.addMessage(models.Message{
+		reply := models.Message{
 			Role:    models.MessageRoleUser,
-			Content: input,
-		})
+			Content: userInput,
+		}
 
-		m.input.SetValue("")
-
 		if m.persistence {
-			err := m.persistConversation()
+			var err error
+			if m.conversation.ID == 0 {
+				err = m.ctx.Store.SaveConversation(m.conversation)
+			}
 			if err != nil {
 				return true, wrapError(err)
 			}
+
+			// ensure all messages up to the one we're about to add are persisted
+			cmd := m.persistConversation()
+			if cmd != nil {
+				return true, cmd
+			}
+
+			savedReply, err := m.ctx.Store.AddReply(m.conversation, reply)
+			if err != nil {
+				return true, wrapError(err)
+			}
+			reply = *savedReply
 		}
 
+		m.input.SetValue("")
+		m.addMessage(reply)
+
 		m.updateContent()
 		m.content.GotoBottom()
 		return true, m.promptLLM()
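The hunks above repeatedly hand errors back to the UI through wrapError and msgError rather than returning them directly. Neither is defined in this diff, so the following is only a plausible minimal form of that Bubble Tea pattern, with the names and shapes assumed for illustration:

package example

import (
	tea "github.com/charmbracelet/bubbletea"
)

// msgError and wrapError are assumed shapes, not taken from this diff: an
// error is wrapped in a command so it is delivered to Update as a message.
type msgError error

func wrapError(err error) tea.Cmd {
	return func() tea.Msg {
		return msgError(err)
	}
}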
@ -679,16 +693,10 @@ func (m *chatModel) footerView() string {
 		saving,
 		segmentStyle.Render(status),
 	}
-	rightSegments := []string{}
-
-	if m.elapsed > 0 && m.tokenCount > 0 {
-		throughput := fmt.Sprintf("%.0f t/sec", float64(m.tokenCount)/m.elapsed.Seconds())
-		rightSegments = append(rightSegments, segmentStyle.Render(throughput))
+	rightSegments := []string{
+		segmentStyle.Render(fmt.Sprintf("Model: %s", *m.ctx.Config.Defaults.Model)),
 	}
 
-	model := fmt.Sprintf("Model: %s", *m.ctx.Config.Defaults.Model)
-	rightSegments = append(rightSegments, segmentStyle.Render(model))
-
 	left := strings.Join(leftSegments, segmentSeparator)
 	right := strings.Join(rightSegments, segmentSeparator)
 
@ -762,7 +770,7 @@ func (m *chatModel) loadConversation(shortname string) tea.Cmd {
 
 func (m *chatModel) loadMessages(c *models.Conversation) tea.Cmd {
 	return func() tea.Msg {
-		messages, err := m.ctx.Store.PathToLeaf(c.SelectedRoot)
+		messages, err := m.ctx.Store.Messages(c)
 		if err != nil {
 			return msgError(fmt.Errorf("Could not load conversation messages: %v\n", err))
 		}
@ -770,48 +778,62 @@ func (m *chatModel) loadMessages(c *models.Conversation) tea.Cmd {
 	}
 }
 
-func (m *chatModel) persistConversation() error {
-	if m.conversation.ID == 0 {
-		// Start a new conversation with all messages so far
-		c, messages, err := m.ctx.Store.StartConversation(m.messages...)
-		if err != nil {
-			return err
-		}
-		m.conversation = c
-		m.messages = messages
-
-		return nil
-	}
-
-	// else, we'll handle updating an existing conversation's messages
-	for i := 0; i < len(m.messages); i++ {
-		if m.messages[i].ID > 0 {
-			// message has an ID, update its contents
-			// TODO: check for content/tool equality before updating?
-			err := m.ctx.Store.UpdateMessage(&m.messages[i])
-			if err != nil {
-				return err
-			}
-		} else if i > 0 {
-			// messages is new, so add it as a reply to previous message
-			saved, err := m.ctx.Store.Reply(&m.messages[i-1], m.messages[i])
-			if err != nil {
-				return err
-			}
-			m.messages[i] = saved[0]
-		} else {
-			// message has no id and no previous messages to add it to
-			// this shouldn't happen?
-			return fmt.Errorf("Error: no messages to reply to")
-		}
-	}
-
-	return nil
-}
+func (m *chatModel) persistConversation() tea.Cmd {
+	existingMessages, err := m.ctx.Store.Messages(m.conversation)
+	if err != nil {
+		return wrapError(fmt.Errorf("Could not retrieve existing conversation messages while trying to save: %v", err))
+	}
+
+	existingById := make(map[uint]*models.Message, len(existingMessages))
+	for _, msg := range existingMessages {
+		existingById[msg.ID] = &msg
+	}
+
+	currentById := make(map[uint]*models.Message, len(m.messages))
+	for _, msg := range m.messages {
+		currentById[msg.ID] = &msg
+	}
+
+	for _, msg := range existingMessages {
+		_, ok := currentById[msg.ID]
+		if !ok {
+			err := m.ctx.Store.DeleteMessage(&msg)
+			if err != nil {
+				return wrapError(fmt.Errorf("Failed to remove messages: %v", err))
+			}
+		}
+	}
+
+	for i, msg := range m.messages {
+		if msg.ID > 0 {
+			exist, ok := existingById[msg.ID]
+			if ok {
+				if msg.Content == exist.Content {
+					continue
+				}
+				// update message when contents don't match that of store
+				err := m.ctx.Store.UpdateMessage(&msg)
+				if err != nil {
+					return wrapError(err)
+				}
+			} else {
+				// this would be quite odd... and I'm not sure how to handle
+				// it at the time of writing this
+			}
+		} else {
+			newMessage, err := m.ctx.Store.AddReply(m.conversation, msg)
+			if err != nil {
+				return wrapError(err)
+			}
+			m.setMessage(i, *newMessage)
+		}
+	}
+
+	return nil
+}
 
 func (m *chatModel) generateConversationTitle() tea.Cmd {
 	return func() tea.Msg {
-		title, err := cmdutil.GenerateTitle(m.ctx, m.messages)
+		title, err := cmdutil.GenerateTitle(m.ctx, m.conversation)
 		if err != nil {
 			return msgError(err)
 		}
@ -835,10 +857,6 @@ func (m *chatModel) promptLLM() tea.Cmd {
 	m.waitingForReply = true
 	m.status = "Press ctrl+c to cancel"
 
-	m.tokenCount = 0
-	m.startTime = time.Now()
-	m.elapsed = 0
-
 	return func() tea.Msg {
 		completionProvider, err := m.ctx.GetCompletionProvider(*m.ctx.Config.Defaults.Model)
 		if err != nil {
@ -2,6 +2,7 @@ package tui
 
 import (
 	"fmt"
+	"slices"
 	"strings"
 	"time"
 
@ -114,7 +115,7 @@ func (m *conversationsModel) handleResize(width, height int) {
 func (m conversationsModel) Update(msg tea.Msg) (conversationsModel, tea.Cmd) {
 	var cmds []tea.Cmd
 	switch msg := msg.(type) {
-	case msgStateEnter:
+	case msgStateChange:
 		cmds = append(cmds, m.loadConversations())
 		m.content.SetContent(m.renderConversationList())
 	case tea.WindowSizeMsg:
@ -144,17 +145,25 @@ func (m conversationsModel) Update(msg tea.Msg) (conversationsModel, tea.Cmd) {
 
 func (m *conversationsModel) loadConversations() tea.Cmd {
 	return func() tea.Msg {
-		messages, err := m.ctx.Store.LatestConversationMessages()
+		conversations, err := m.ctx.Store.Conversations()
 		if err != nil {
 			return msgError(fmt.Errorf("Could not load conversations: %v", err))
 		}
 
-		loaded := make([]loadedConversation, len(messages))
-		for i, m := range messages {
-			loaded[i].lastReply = m
-			loaded[i].conv = m.Conversation
+		loaded := make([]loadedConversation, len(conversations))
+		for i, c := range conversations {
+			lastMessage, err := m.ctx.Store.LastMessage(&c)
+			if err != nil {
+				return msgError(err)
+			}
+			loaded[i].conv = c
+			loaded[i].lastReply = *lastMessage
 		}
 
+		slices.SortFunc(loaded, func(a, b loadedConversation) int {
+			return b.lastReply.CreatedAt.Compare(a.lastReply.CreatedAt)
+		})
+
 		return msgConversationsLoaded(loaded)
 	}
 }
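The sort added above orders conversations newest-first: because the comparison function compares b against a, more recent timestamps sort before older ones. A standalone sketch of the same pattern:

package main

import (
	"fmt"
	"slices"
	"time"
)

func main() {
	times := []time.Time{
		time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
		time.Date(2024, 3, 1, 0, 0, 0, 0, time.UTC),
		time.Date(2024, 2, 1, 0, 0, 0, 0, time.UTC),
	}

	// b.Compare(a) instead of a.Compare(b) reverses the order, giving a
	// descending (newest-first) sort, as in loadConversations above.
	slices.SortFunc(times, func(a, b time.Time) int {
		return b.Compare(a)
	})

	fmt.Println(times[0].Format("2006-01-02")) // 2024-03-01
}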
@ -147,16 +147,16 @@ func SetStructDefaults(data interface{}) bool {
 		case reflect.String:
 			defaultValue := defaultTag
 			field.Set(reflect.ValueOf(&defaultValue))
-		case reflect.Uint, reflect.Uint32, reflect.Uint64:
-			intValue, _ := strconv.ParseUint(defaultTag, 10, e.Bits())
-			field.Set(reflect.New(e))
-			field.Elem().SetUint(intValue)
 		case reflect.Int, reflect.Int32, reflect.Int64:
-			intValue, _ := strconv.ParseInt(defaultTag, 10, e.Bits())
+			intValue, _ := strconv.ParseInt(defaultTag, 10, 64)
 			field.Set(reflect.New(e))
 			field.Elem().SetInt(intValue)
-		case reflect.Float32, reflect.Float64:
-			floatValue, _ := strconv.ParseFloat(defaultTag, e.Bits())
+		case reflect.Float32:
+			floatValue, _ := strconv.ParseFloat(defaultTag, 32)
+			field.Set(reflect.New(e))
+			field.Elem().SetFloat(floatValue)
+		case reflect.Float64:
+			floatValue, _ := strconv.ParseFloat(defaultTag, 64)
 			field.Set(reflect.New(e))
 			field.Elem().SetFloat(floatValue)
 		case reflect.Bool:
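For context, SetStructDefaults populates nil pointer fields from a struct tag. The tag key is not visible in this hunk ("default" is assumed below, and the field names and values are illustrative only), and the sketch assumes it lives in the same package as SetStructDefaults:

// Assumed example only: the `default` tag key and these field names are not
// taken from the diff; they just show pointer fields of the kinds handled by
// the String / Int / Float32 / Float64 / Bool cases above.
type exampleDefaults struct {
	Model       *string  `default:"gpt-4"`
	MaxTokens   *int     `default:"256"`
	Temperature *float32 `default:"0.7"`
	Stream      *bool    `default:"true"`
}

func loadExampleDefaults() *exampleDefaults {
	d := &exampleDefaults{}
	// SetStructDefaults walks the struct via reflection; for each nil pointer
	// field it parses the tag value and allocates a new value of the field's
	// element type, per the switch cases shown in the hunk.
	SetStructDefaults(d)
	return d
}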