Compare commits
14 Commits: aeeb7bb7f7 ... db465f1bf0

db465f1bf0
f6e55f6bff
dc1edf8c3e
62d98289e8
b82f3019f0
1bd953676d
a291e7b42c
1b8d04c96d
cbcd3b1ba9
75bf9f6125
9ff4322995
54f5a3c209
86bdc733bf
60394de620
@@ -3,6 +3,7 @@ package cmd
 import (
     "fmt"

+    cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
     "git.mlow.ca/mlow/lmcli/pkg/lmcli"
     "git.mlow.ca/mlow/lmcli/pkg/tui"
     "github.com/spf13/cobra"
@@ -14,11 +15,16 @@ func ChatCmd(ctx *lmcli.Context) *cobra.Command {
         Short: "Open the chat interface",
         Long: `Open the chat interface, optionally on a given conversation.`,
         RunE: func(cmd *cobra.Command, args []string) error {
-            // TODO: implement jump-to-conversation logic
             shortname := ""
             if len(args) == 1 {
                 shortname = args[0]
             }
+            if shortname != ""{
+                _, err := cmdutil.LookupConversationE(ctx, shortname)
+                if err != nil {
+                    return err
+                }
+            }
             err := tui.Launch(ctx, shortname)
             if err != nil {
                 return fmt.Errorf("Error fetching LLM response: %v", err)

@@ -5,7 +5,6 @@ import (

     cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
     "git.mlow.ca/mlow/lmcli/pkg/lmcli"
-    "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
     "github.com/spf13/cobra"
 )

@@ -28,36 +27,12 @@ func CloneCmd(ctx *lmcli.Context) *cobra.Command {
                 return err
             }

-            messagesToCopy, err := ctx.Store.Messages(toClone)
+            clone, messageCnt, err := ctx.Store.CloneConversation(*toClone)
             if err != nil {
-                return fmt.Errorf("Could not retrieve messages for conversation: %s", toClone.ShortName.String)
+                return fmt.Errorf("Failed to clone conversation: %v", err)
             }

-            clone := &model.Conversation{
-                Title: toClone.Title + " - Clone",
-            }
-            if err := ctx.Store.SaveConversation(clone); err != nil {
-                return fmt.Errorf("Cloud not create clone: %s", err)
-            }
-
-            var errors []error
-            messageCnt := 0
-            for _, message := range messagesToCopy {
-                newMessage := message
-                newMessage.ConversationID = clone.ID
-                newMessage.ID = 0
-                if err := ctx.Store.SaveMessage(&newMessage); err != nil {
-                    errors = append(errors, err)
-                } else {
-                    messageCnt++
-                }
-            }
-
-            if len(errors) > 0 {
-                return fmt.Errorf("Messages failed to be cloned: %v", errors)
-            }
-
-            fmt.Printf("Cloned %d messages to: %s\n", messageCnt, clone.Title)
+            fmt.Printf("Cloned %d messages to: %s - %s\n", messageCnt, clone.ShortName.String, clone.Title)
             return nil
         },
         ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {

@@ -26,7 +26,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
             shortName := args[0]
             conversation := cmdutil.LookupConversation(ctx, shortName)

-            messages, err := ctx.Store.Messages(conversation)
+            messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
             if err != nil {
                 return fmt.Errorf("could not retrieve conversation messages: %v", err)
             }

@@ -24,7 +24,7 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
             shortName := args[0]
             conversation := cmdutil.LookupConversation(ctx, shortName)

-            messages, err := ctx.Store.Messages(conversation)
+            messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
             if err != nil {
                 return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
             }
@@ -39,21 +39,7 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
             }

             desiredIdx := len(messages) - 1 - offset
-            // walk backwards through the conversation deleting messages until and
-            // including the last user message
-            toRemove := []model.Message{}
-            var toEdit *model.Message
-            for i := len(messages) - 1; i >= 0; i-- {
-                if i == desiredIdx {
-                    toEdit = &messages[i]
-                }
-                toRemove = append(toRemove, messages[i])
-                messages = messages[:i]
-                if toEdit != nil {
-                    break
-                }
-            }
+            toEdit := messages[desiredIdx]

             newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", toEdit.Content)
             switch newContents {
@@ -63,26 +49,17 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
                 return fmt.Errorf("No message was provided.")
             }

+            toEdit.Content = newContents

             role, _ := cmd.Flags().GetString("role")
-            if role == "" {
-                role = string(toEdit.Role)
-            } else if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) {
+            if role != "" {
+                if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) {
                 return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.")
             }
-            for _, message := range toRemove {
-                err = ctx.Store.DeleteMessage(&message)
-                if err != nil {
-                    lmcli.Warn("Could not delete message: %v\n", err)
-                }
+                toEdit.Role = model.MessageRole(role)
             }

-            cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{
-                ConversationID: conversation.ID,
-                Role: model.MessageRole(role),
-                Content: newContents,
-            })
-            return nil
+            return ctx.Store.UpdateMessage(&toEdit)
         },
         ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
             compMode := cobra.ShellCompDirectiveNoFileComp

@@ -2,7 +2,6 @@ package cmd

 import (
     "fmt"
-    "slices"
     "time"

     "git.mlow.ca/mlow/lmcli/pkg/lmcli"
@@ -21,7 +20,7 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
         Short: "List conversations",
         Long: `List conversations in order of recent activity`,
         RunE: func(cmd *cobra.Command, args []string) error {
-            conversations, err := ctx.Store.Conversations()
+            messages, err := ctx.Store.LatestConversationMessages()
             if err != nil {
                 return fmt.Errorf("Could not fetch conversations: %v", err)
             }
@@ -58,13 +57,8 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {

             all, _ := cmd.Flags().GetBool("all")

-            for _, conversation := range conversations {
-                lastMessage, err := ctx.Store.LastMessage(&conversation)
-                if lastMessage == nil || err != nil {
-                    continue
-                }
-
-                messageAge := now.Sub(lastMessage.CreatedAt)
-
+            for _, message := range messages {
+                messageAge := now.Sub(message.CreatedAt)
                 var category string
                 for _, c := range categories {
@@ -76,9 +70,9 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {

                 formatted := fmt.Sprintf(
                     "%s - %s - %s",
-                    conversation.ShortName.String,
+                    message.Conversation.ShortName.String,
                     util.HumanTimeElapsedSince(messageAge),
-                    conversation.Title,
+                    message.Conversation.Title,
                 )

                 categorized[category] = append(
@@ -96,14 +90,10 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
                 continue
             }

-            slices.SortFunc(conversationLines, func(a, b ConversationLine) int {
-                return int(a.timeSinceReply - b.timeSinceReply)
-            })
-
             fmt.Printf("%s:\n", category.name)
             for _, conv := range conversationLines {
                 if conversationsPrinted >= count && !all {
-                    fmt.Printf("%d remaining message(s), use --all to view.\n", len(conversations)-conversationsPrinted)
+                    fmt.Printf("%d remaining conversation(s), use --all to view.\n", len(messages)-conversationsPrinted)
                     break outer
                 }

@@ -15,42 +15,43 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
         Short: "Start a new conversation",
         Long: `Start a new conversation with the Large Language Model.`,
         RunE: func(cmd *cobra.Command, args []string) error {
-            messageContents := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
-            if messageContents == "" {
+            input := inputFromArgsOrEditor(args, "# Start a new conversation below\n", "")
+            if input == "" {
                 return fmt.Errorf("No message was provided.")
             }

-            conversation := &model.Conversation{}
-            err := ctx.Store.SaveConversation(conversation)
-            if err != nil {
-                return fmt.Errorf("Could not save new conversation: %v", err)
-            }
-
-            messages := []model.Message{
-                {
-                    ConversationID: conversation.ID,
+            var messages []model.Message
+
+            // TODO: probably just make this part of the conversation
+            system := ctx.GetSystemPrompt()
+            if system != "" {
+                messages = append(messages, model.Message{
                     Role: model.MessageRoleSystem,
-                    Content: ctx.GetSystemPrompt(),
-                },
-                {
-                    ConversationID: conversation.ID,
-                    Role: model.MessageRoleUser,
-                    Content: messageContents,
-                },
+                    Content: system,
+                })
             }

-            cmdutil.HandleConversationReply(ctx, conversation, true, messages...)
-
-            title, err := cmdutil.GenerateTitle(ctx, conversation)
+            messages = append(messages, model.Message{
+                Role: model.MessageRoleUser,
+                Content: input,
+            })
+
+            conversation, messages, err := ctx.Store.StartConversation(messages...)
             if err != nil {
-                lmcli.Warn("Could not generate title for conversation: %v\n", err)
+                return fmt.Errorf("Could not start a new conversation: %v", err)
+            }
+
+            cmdutil.HandleReply(ctx, &messages[len(messages)-1], true)
+
+            title, err := cmdutil.GenerateTitle(ctx, messages)
+            if err != nil {
+                lmcli.Warn("Could not generate title for conversation %s: %v\n", conversation.ShortName.String, err)
             }

             conversation.Title = title
-            err = ctx.Store.SaveConversation(conversation)
+            err = ctx.Store.UpdateConversation(conversation)
             if err != nil {
-                lmcli.Warn("Could not save conversation after generating title: %v\n", err)
+                lmcli.Warn("Could not save conversation title: %v\n", err)
             }
             return nil
         },

@@ -15,22 +15,27 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
         Short: "Do a one-shot prompt",
         Long: `Prompt the Large Language Model and get a response.`,
         RunE: func(cmd *cobra.Command, args []string) error {
-            message := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
-            if message == "" {
+            input := inputFromArgsOrEditor(args, "# Write your prompt below\n", "")
+            if input == "" {
                 return fmt.Errorf("No message was provided.")
             }

-            messages := []model.Message{
-                {
+            var messages []model.Message
+
+            // TODO: stop supplying system prompt as a message
+            system := ctx.GetSystemPrompt()
+            if system != "" {
+                messages = append(messages, model.Message{
                     Role: model.MessageRoleSystem,
-                    Content: ctx.GetSystemPrompt(),
-                },
-                {
-                    Role: model.MessageRoleUser,
-                    Content: message,
-                },
+                    Content: system,
+                })
             }

+            messages = append(messages, model.Message{
+                Role: model.MessageRoleUser,
+                Content: input,
+            })
+
             _, err := cmdutil.Prompt(ctx, messages, nil)
             if err != nil {
                 return fmt.Errorf("Error fetching LLM response: %v", err)

@@ -24,12 +24,17 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
         RunE: func(cmd *cobra.Command, args []string) error {
             shortName := args[0]
             conversation := cmdutil.LookupConversation(ctx, shortName)

             var err error
+            var title string
+
             generate, _ := cmd.Flags().GetBool("generate")
-            var title string
             if generate {
-                title, err = cmdutil.GenerateTitle(ctx, conversation)
+                messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
+                if err != nil {
+                    return fmt.Errorf("Could not retrieve conversation messages: %v", err)
+                }
+                title, err = cmdutil.GenerateTitle(ctx, messages)
                 if err != nil {
                     return fmt.Errorf("Could not generate conversation title: %v", err)
                 }
@@ -41,9 +46,9 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
             }

             conversation.Title = title
-            err = ctx.Store.SaveConversation(conversation)
+            err = ctx.Store.UpdateConversation(conversation)
             if err != nil {
-                lmcli.Warn("Could not save conversation with new title: %v\n", err)
+                lmcli.Warn("Could not update conversation title: %v\n", err)
             }
             return nil
         },

@@ -13,7 +13,7 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
     cmd := &cobra.Command{
         Use: "retry <conversation>",
         Short: "Retry the last user reply in a conversation",
-        Long: `Re-prompt the conversation up to the last user response. Can be used to regenerate the last assistant reply, or simply generate one if an error occurred.`,
+        Long: `Prompt the conversation from the last user response.`,
         Args: func(cmd *cobra.Command, args []string) error {
             argCount := 1
             if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
@@ -25,25 +25,28 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
             shortName := args[0]
             conversation := cmdutil.LookupConversation(ctx, shortName)

-            messages, err := ctx.Store.Messages(conversation)
+            // Load the complete thread from the root message
+            messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
             if err != nil {
                 return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
             }

-            // walk backwards through the conversation and delete messages, break
-            // when we find the latest user response
-            for i := len(messages) - 1; i >= 0; i-- {
+            // Find the last user message in the conversation
+            var lastUserMessage *model.Message
+            var i int
+            for i = len(messages) - 1; i >= 0; i-- {
                 if messages[i].Role == model.MessageRoleUser {
+                    lastUserMessage = &messages[i]
                     break
                 }
-
-                err = ctx.Store.DeleteMessage(&messages[i])
-                if err != nil {
-                    lmcli.Warn("Could not delete previous reply: %v\n", err)
-                }
             }

-            cmdutil.HandleConversationReply(ctx, conversation, true)
+            if lastUserMessage == nil {
+                return fmt.Errorf("No user message found in the conversation: %s", conversation.Title)
+            }
+
+            // Start a new branch at the last user message
+            cmdutil.HandleReply(ctx, lastUserMessage, true)
             return nil
         },
         ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {

@@ -57,7 +57,7 @@ func LookupConversation(ctx *lmcli.Context, shortName string) *model.Conversation {
         lmcli.Fatal("Could not lookup conversation: %v\n", err)
     }
     if c.ID == 0 {
-        lmcli.Fatal("Conversation not found with short name: %s\n", shortName)
+        lmcli.Fatal("Conversation not found: %s\n", shortName)
     }
     return c
 }
@@ -68,48 +68,63 @@ func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversation, error) {
         return nil, fmt.Errorf("Could not lookup conversation: %v", err)
     }
     if c.ID == 0 {
-        return nil, fmt.Errorf("Conversation not found with short name: %s", shortName)
+        return nil, fmt.Errorf("Conversation not found: %s", shortName)
     }
     return c, nil
 }

+func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
+    messages, err := ctx.Store.PathToLeaf(c.SelectedRoot)
+    if err != nil {
+        lmcli.Fatal("Could not load messages: %v\n", err)
+    }
+    HandleReply(ctx, &messages[len(messages)-1], persist, toSend...)
+}
+
 // handleConversationReply handles sending messages to an existing
 // conversation, optionally persisting both the sent replies and responses.
-func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
-    existing, err := ctx.Store.Messages(c)
+func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages ...model.Message) {
+    if to == nil {
+        lmcli.Fatal("Can't prompt from an empty message.")
+    }
+
+    existing, err := ctx.Store.PathToRoot(to)
     if err != nil {
-        lmcli.Fatal("Could not retrieve messages for conversation: %s\n", c.Title)
+        lmcli.Fatal("Could not load messages: %v\n", err)
     }

-    if persist {
-        for _, message := range toSend {
-            err = ctx.Store.SaveMessage(&message)
+    RenderConversation(ctx, append(existing, messages...), true)
+
+    var savedReplies []model.Message
+    if persist && len(messages) > 0 {
+        savedReplies, err = ctx.Store.Reply(to, messages...)
         if err != nil {
-            lmcli.Warn("Could not save %s message: %v\n", message.Role, err)
+            lmcli.Warn("Could not save messages: %v\n", err)
         }
     }
-    }
-
-    allMessages := append(existing, toSend...)
-
-    RenderConversation(ctx, allMessages, true)
-
     // render a message header with no contents
     RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
+
+    var lastMessage *model.Message
+
+    lastMessage = to
+    if len(savedReplies) > 1 {
+        lastMessage = &savedReplies[len(savedReplies)-1]
+    }
+
     replyCallback := func(reply model.Message) {
         if !persist {
             return
         }
-        reply.ConversationID = c.ID
-        err = ctx.Store.SaveMessage(&reply)
+        savedReplies, err = ctx.Store.Reply(lastMessage, reply)
         if err != nil {
             lmcli.Warn("Could not save reply: %v\n", err)
         }
+        lastMessage = &savedReplies[0]
     }

-    _, err = Prompt(ctx, allMessages, replyCallback)
+    _, err = Prompt(ctx, append(existing, messages...), replyCallback)
     if err != nil {
         lmcli.Fatal("Error fetching LLM response: %v\n", err)
     }
@@ -134,12 +149,7 @@ func FormatForExternalPrompt(messages []model.Message, system bool) string {
     return sb.String()
 }

-func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) {
-    messages, err := ctx.Store.Messages(c)
-    if err != nil {
-        return "", err
-    }
-
+func GenerateTitle(ctx *lmcli.Context, messages []model.Message) (string, error) {
     const prompt = `Above is an excerpt from a conversation between a user and AI assistant. Please reply with a short title (no more than 8 words) that reflects the topic of the conversation, read from the user's perspective.

 Example conversation:
@@ -186,6 +196,7 @@ Title: A brief introduction

     response = strings.TrimPrefix(response, "Title: ")
     response = strings.Trim(response, "\"")
+    response = strings.TrimSpace(response)

     return response, nil
 }

@@ -24,9 +24,9 @@ func ViewCmd(ctx *lmcli.Context) *cobra.Command {
             shortName := args[0]
             conversation := cmdutil.LookupConversation(ctx, shortName)

-            messages, err := ctx.Store.Messages(conversation)
+            messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
             if err != nil {
-                return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
+                return fmt.Errorf("Could not retrieve messages for conversation %s: %v", conversation.ShortName.String, err)
             }

             cmdutil.RenderConversation(ctx, messages, false)

@@ -22,6 +22,7 @@ type Config struct {
         EnabledTools []string `yaml:"enabledTools"`
     } `yaml:"tools"`
     Providers []*struct {
+        Name *string `yaml:"name"`
         Kind *string `yaml:"kind"`
         BaseURL *string `yaml:"baseUrl"`
         APIKey *string `yaml:"apiKey"`

@@ -4,10 +4,12 @@ import (
     "fmt"
     "os"
     "path/filepath"
+    "strings"

     "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
     "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
     "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
+    "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/google"
     "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
     "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
     "git.mlow.ca/mlow/lmcli/pkg/util"
@@ -34,7 +36,9 @@ func NewContext() (*Context, error) {
     }

     databaseFile := filepath.Join(dataDir(), "conversations.db")
-    db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{})
+    db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{
+        //Logger: logger.Default.LogMode(logger.Info),
+    })
     if err != nil {
         return nil, fmt.Errorf("Error establishing connection to store: %v", err)
     }
@@ -57,16 +61,37 @@ func NewContext() (*Context, error) {
 }

 func (c *Context) GetModels() (models []string) {
+    modelCounts := make(map[string]int)
     for _, p := range c.Config.Providers {
         for _, m := range *p.Models {
+            modelCounts[m]++
+            models = append(models, *p.Name+"/"+m)
+        }
+    }
+
+    for m, c := range modelCounts {
+        if c == 1 {
             models = append(models, m)
         }
     }

     return
 }

 func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) {
+    parts := strings.Split(model, "/")
+
+    var provider string
+    if len(parts) > 1 {
+        provider = parts[0]
+        model = parts[1]
+    }
+
     for _, p := range c.Config.Providers {
+        if provider != "" && *p.Name != provider {
+            continue
+        }
+
         for _, m := range *p.Models {
             if m == model {
                 switch *p.Kind {
@@ -75,21 +100,28 @@ func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) {
                     if p.BaseURL != nil {
                         url = *p.BaseURL
                     }
-                    anthropic := &anthropic.AnthropicClient{
+                    return &anthropic.AnthropicClient{
                         BaseURL: url,
                         APIKey: *p.APIKey,
+                    }, nil
+                case "google":
+                    url := "https://generativelanguage.googleapis.com"
+                    if p.BaseURL != nil {
+                        url = *p.BaseURL
                     }
-                    return anthropic, nil
+                    return &google.Client{
+                        BaseURL: url,
+                        APIKey: *p.APIKey,
+                    }, nil
                 case "openai":
                     url := "https://api.openai.com/v1"
                     if p.BaseURL != nil {
                         url = *p.BaseURL
                     }
-                    openai := &openai.OpenAIClient{
+                    return &openai.OpenAIClient{
                         BaseURL: url,
                         APIKey: *p.APIKey,
-                    }
-                    return openai, nil
+                    }, nil
                 default:
                     return nil, fmt.Errorf("unknown provider kind: %s", *p.Kind)
                 }

@@ -17,18 +17,27 @@ const (

 type Message struct {
     ID uint `gorm:"primaryKey"`
-    ConversationID uint `gorm:"foreignKey:ConversationID"`
+    ConversationID uint `gorm:"index"`
+    Conversation Conversation `gorm:"foreignKey:ConversationID"`
     Content string
     Role MessageRole
-    CreatedAt time.Time
-    ToolCalls ToolCalls // a json array of tool calls (from the modl)
+    CreatedAt time.Time `gorm:"index"`
+    ToolCalls ToolCalls // a json array of tool calls (from the model)
     ToolResults ToolResults // a json array of tool results
+    ParentID *uint `gorm:"index"`
+    Parent *Message `gorm:"foreignKey:ParentID"`
+    Replies []Message `gorm:"foreignKey:ParentID"`
+
+    SelectedReplyID *uint
+    SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
 }

 type Conversation struct {
     ID uint `gorm:"primaryKey"`
     ShortName sql.NullString
     Title string
+    SelectedRootID *uint
+    SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
 }

 type RequestParameters struct {

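The struct changes above replace the flat, conversation-keyed message list with a tree: each Message now records its ParentID/Parent, its Replies, and a SelectedReply, and each Conversation points at a SelectedRoot. As an illustration of how the selected branch might be walked, here is a minimal, self-contained sketch; the types are simplified stand-ins, not the actual model package, and the traversal is an assumption based only on the fields shown above.

package main

import "fmt"

// Simplified stand-ins for model.Message, keeping only the
// tree-related fields introduced in the diff above.
type Message struct {
	ID            uint
	Content       string
	Parent        *Message   // nil for a root message
	Replies       []*Message // all replies to this message
	SelectedReply *Message   // the reply currently selected on this branch
}

// pathToLeaf walks from a root message down the chain of selected
// replies, mirroring what a store-level PathToLeaf is assumed to do.
func pathToLeaf(root *Message) []*Message {
	var path []*Message
	for m := root; m != nil; m = m.SelectedReply {
		path = append(path, m)
	}
	return path
}

func main() {
	root := &Message{ID: 1, Content: "system prompt"}
	user := &Message{ID: 2, Content: "hello", Parent: root}
	asst := &Message{ID: 3, Content: "hi there", Parent: user}
	root.Replies = []*Message{user}
	root.SelectedReply = user
	user.Replies = []*Message{asst}
	user.SelectedReply = asst

	for _, m := range pathToLeaf(root) {
		fmt.Println(m.ID, m.Content)
	}
}
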
pkg/lmcli/provider/google/google.go (new file, 413 lines)
@@ -0,0 +1,413 @@
+package google
+
+import (
+    "bufio"
+    "bytes"
+    "context"
+    "encoding/json"
+    "fmt"
+    "io"
+    "net/http"
+    "strings"
+
+    "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
+    "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
+    "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
+)
+
+func convertTools(tools []model.Tool) []Tool {
+    geminiTools := make([]Tool, len(tools))
+    for i, tool := range tools {
+        params := make(map[string]ToolParameter)
+        var required []string
+
+        for _, param := range tool.Parameters {
+            // TODO: proper enum handing
+            params[param.Name] = ToolParameter{
+                Type: param.Type,
+                Description: param.Description,
+                Values: param.Enum,
+            }
+            if param.Required {
+                required = append(required, param.Name)
+            }
+        }
+
+        geminiTools[i] = Tool{
+            FunctionDeclarations: []FunctionDeclaration{
+                {
+                    Name: tool.Name,
+                    Description: tool.Description,
+                    Parameters: ToolParameters{
+                        Type: "OBJECT",
+                        Properties: params,
+                        Required: required,
+                    },
+                },
+            },
+        }
+    }
+    return geminiTools
+}
+
+func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart {
+    converted := make([]ContentPart, len(toolCalls))
+    for i, call := range toolCalls {
+        args := make(map[string]string)
+        for k, v := range call.Parameters {
+            args[k] = fmt.Sprintf("%v", v)
+        }
+        converted[i].FunctionCall = &FunctionCall{
+            Name: call.Name,
+            Args: args,
+        }
+    }
+    return converted
+}
+
+func convertToolCallToAPI(functionCalls []FunctionCall) []model.ToolCall {
+    converted := make([]model.ToolCall, len(functionCalls))
+    for i, call := range functionCalls {
+        params := make(map[string]interface{})
+        for k, v := range call.Args {
+            params[k] = v
+        }
+        converted[i].Name = call.Name
+        converted[i].Parameters = params
+    }
+    return converted
+}
+
+func convertToolResultsToGemini(toolResults []model.ToolResult) ([]FunctionResponse, error) {
+    results := make([]FunctionResponse, len(toolResults))
+    for i, result := range toolResults {
+        var obj interface{}
+        err := json.Unmarshal([]byte(result.Result), &obj)
+        if err != nil {
+            return nil, fmt.Errorf("Could not unmarshal %s: %v", result.Result, err)
+        }
+        results[i] = FunctionResponse{
+            Name: result.ToolName,
+            Response: obj,
+        }
+    }
+    return results, nil
+}
+
+func createGenerateContentRequest(
+    params model.RequestParameters,
+    messages []model.Message,
+) (*GenerateContentRequest, error) {
+    requestContents := make([]Content, 0, len(messages))
+
+    startIdx := 0
+    var system string
+    if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
+        system = messages[0].Content
+        startIdx = 1
+    }
+
+    for _, m := range messages[startIdx:] {
+        switch m.Role {
+        case "tool_call":
+            content := Content{
+                Role: "model",
+                Parts: convertToolCallToGemini(m.ToolCalls),
+            }
+            requestContents = append(requestContents, content)
+        case "tool_result":
+            results, err := convertToolResultsToGemini(m.ToolResults)
+            if err != nil {
+                return nil, err
+            }
+            // expand tool_result messages' results into multiple gemini messages
+            for _, result := range results {
+                content := Content{
+                    Role: "function",
+                    Parts: []ContentPart{
+                        {
+                            FunctionResp: &result,
+                        },
+                    },
+                }
+                requestContents = append(requestContents, content)
+            }
+        default:
+            var role string
+            switch m.Role {
+            case model.MessageRoleAssistant:
+                role = "model"
+            case model.MessageRoleUser:
+                role = "user"
+            }
+
+            if role == "" {
+                panic("Unhandled role: " + m.Role)
+            }
+
+            content := Content{
+                Role: role,
+                Parts: []ContentPart{
+                    {
+                        Text: m.Content,
+                    },
+                },
+            }
+            requestContents = append(requestContents, content)
+        }
+    }
+
+    request := &GenerateContentRequest{
+        Contents: requestContents,
+        SystemInstructions: system,
+        GenerationConfig: &GenerationConfig{
+            MaxOutputTokens: &params.MaxTokens,
+            Temperature: &params.Temperature,
+            TopP: &params.TopP,
+        },
+    }
+
+    if len(params.ToolBag) > 0 {
+        request.Tools = convertTools(params.ToolBag)
+    }
+
+    return request, nil
+}
+
+func handleToolCalls(
+    params model.RequestParameters,
+    content string,
+    toolCalls []model.ToolCall,
+    callback provider.ReplyCallback,
+    messages []model.Message,
+) ([]model.Message, error) {
+    lastMessage := messages[len(messages)-1]
+    continuation := false
+    if lastMessage.Role.IsAssistant() {
+        continuation = true
+    }
+
+    toolCall := model.Message{
+        Role: model.MessageRoleToolCall,
+        Content: content,
+        ToolCalls: toolCalls,
+    }
+
+    toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
+    if err != nil {
+        return nil, err
+    }
+
+    toolResult := model.Message{
+        Role: model.MessageRoleToolResult,
+        ToolResults: toolResults,
+    }
+
+    if callback != nil {
+        callback(toolCall)
+        callback(toolResult)
+    }
+
+    if continuation {
+        messages[len(messages)-1] = toolCall
+    } else {
+        messages = append(messages, toolCall)
+    }
+    messages = append(messages, toolResult)
+
+    return messages, nil
+}
+
+func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
+    req.Header.Set("Content-Type", "application/json")
+
+    client := &http.Client{}
+    resp, err := client.Do(req.WithContext(ctx))
+
+    if resp.StatusCode != 200 {
+        bytes, _ := io.ReadAll(resp.Body)
+        return resp, fmt.Errorf("%v", string(bytes))
+    }
+
+    return resp, err
+}
+
+func (c *Client) CreateChatCompletion(
+    ctx context.Context,
+    params model.RequestParameters,
+    messages []model.Message,
+    callback provider.ReplyCallback,
+) (string, error) {
+    if len(messages) == 0 {
+        return "", fmt.Errorf("Can't create completion from no messages")
+    }
+
+    req, err := createGenerateContentRequest(params, messages)
+    if err != nil {
+        return "", err
+    }
+    jsonData, err := json.Marshal(req)
+    if err != nil {
+        return "", err
+    }
+
+    url := fmt.Sprintf(
+        "%s/v1beta/models/%s:generateContent?key=%s",
+        c.BaseURL, params.Model, c.APIKey,
+    )
+    httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
+    if err != nil {
+        return "", err
+    }
+
+    resp, err := c.sendRequest(ctx, httpReq)
+    if err != nil {
+        return "", err
+    }
+    defer resp.Body.Close()
+
+    var completionResp GenerateContentResponse
+    err = json.NewDecoder(resp.Body).Decode(&completionResp)
+    if err != nil {
+        return "", err
+    }
+
+    choice := completionResp.Candidates[0]
+
+    var content string
+    lastMessage := messages[len(messages)-1]
+    if lastMessage.Role.IsAssistant() {
+        content = lastMessage.Content
+    }
+
+    var toolCalls []FunctionCall
+    for _, part := range choice.Content.Parts {
+        if part.Text != "" {
+            content += part.Text
+        }
+
+        if part.FunctionCall != nil {
+            toolCalls = append(toolCalls, *part.FunctionCall)
+        }
+    }
+
+    if len(toolCalls) > 0 {
+        messages, err := handleToolCalls(
+            params, content, convertToolCallToAPI(toolCalls), callback, messages,
+        )
+        if err != nil {
+            return content, err
+        }
+
+        return c.CreateChatCompletion(ctx, params, messages, callback)
+    }
+
+    if callback != nil {
+        callback(model.Message{
+            Role: model.MessageRoleAssistant,
+            Content: content,
+        })
+    }
+
+    return content, nil
+}
+
+func (c *Client) CreateChatCompletionStream(
+    ctx context.Context,
+    params model.RequestParameters,
+    messages []model.Message,
+    callback provider.ReplyCallback,
+    output chan<- string,
+) (string, error) {
+    if len(messages) == 0 {
+        return "", fmt.Errorf("Can't create completion from no messages")
+    }
+
+    req, err := createGenerateContentRequest(params, messages)
+    if err != nil {
+        return "", err
+    }
+    jsonData, err := json.Marshal(req)
+    if err != nil {
+        return "", err
+    }
+
+    url := fmt.Sprintf(
+        "%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
+        c.BaseURL, params.Model, c.APIKey,
+    )
+    httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
+    if err != nil {
+        return "", err
+    }
+
+    resp, err := c.sendRequest(ctx, httpReq)
+    if err != nil {
+        return "", err
+    }
+    defer resp.Body.Close()
+
+    content := strings.Builder{}
+
+    lastMessage := messages[len(messages)-1]
+    if lastMessage.Role.IsAssistant() {
+        content.WriteString(lastMessage.Content)
+    }
+
+    var toolCalls []FunctionCall
+
+    reader := bufio.NewReader(resp.Body)
+    for {
+        line, err := reader.ReadBytes('\n')
+        if err != nil {
+            if err == io.EOF {
+                break
+            }
+            return "", err
+        }
+
+        line = bytes.TrimSpace(line)
+        if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
+            continue
+        }
+
+        line = bytes.TrimPrefix(line, []byte("data: "))
+
+        var streamResp GenerateContentResponse
+        err = json.Unmarshal(line, &streamResp)
+        if err != nil {
+            return "", err
+        }
+
+        for _, candidate := range streamResp.Candidates {
+            for _, part := range candidate.Content.Parts {
+                if part.FunctionCall != nil {
+                    toolCalls = append(toolCalls, *part.FunctionCall)
+                } else if part.Text != "" {
+                    output <- part.Text
+                    content.WriteString(part.Text)
+                }
+            }
+        }
+    }
+
+    // If there are function calls, handle them and recurse
+    if len(toolCalls) > 0 {
+        messages, err := handleToolCalls(
+            params, content.String(), convertToolCallToAPI(toolCalls), callback, messages,
+        )
+        if err != nil {
+            return content.String(), err
+        }
+        return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
+    }
+
+    if callback != nil {
+        callback(model.Message{
+            Role: model.MessageRoleAssistant,
+            Content: content.String(),
+        })
+    }
+
+    return content.String(), nil
+}

pkg/lmcli/provider/google/types.go (new file, 80 lines)
@@ -0,0 +1,80 @@
+package google
+
+type Client struct {
+    APIKey string
+    BaseURL string
+}
+
+type ContentPart struct {
+    Text string `json:"text,omitempty"`
+    FunctionCall *FunctionCall `json:"functionCall,omitempty"`
+    FunctionResp *FunctionResponse `json:"functionResponse,omitempty"`
+}
+
+type FunctionCall struct {
+    Name string `json:"name"`
+    Args map[string]string `json:"args"`
+}
+
+type FunctionResponse struct {
+    Name string `json:"name"`
+    Response interface{} `json:"response"`
+}
+
+type Content struct {
+    Role string `json:"role"`
+    Parts []ContentPart `json:"parts"`
+}
+
+type GenerationConfig struct {
+    MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
+    Temperature *float32 `json:"temperature,omitempty"`
+    TopP *float32 `json:"topP,omitempty"`
+    TopK *int `json:"topK,omitempty"`
+}
+
+type GenerateContentRequest struct {
+    Contents []Content `json:"contents"`
+    Tools []Tool `json:"tools,omitempty"`
+    SystemInstructions string `json:"systemInstructions,omitempty"`
+    GenerationConfig *GenerationConfig `json:"generationConfig,omitempty"`
+}
+
+type Candidate struct {
+    Content Content `json:"content"`
+    FinishReason string `json:"finishReason"`
+    Index int `json:"index"`
+}
+
+type UsageMetadata struct {
+    PromptTokenCount int `json:"promptTokenCount"`
+    CandidatesTokenCount int `json:"candidatesTokenCount"`
+    TotalTokenCount int `json:"totalTokenCount"`
+}
+
+type GenerateContentResponse struct {
+    Candidates []Candidate `json:"candidates"`
+    UsageMetadata UsageMetadata `json:"usageMetadata"`
+}
+
+type Tool struct {
+    FunctionDeclarations []FunctionDeclaration `json:"functionDeclarations"`
+}
+
+type FunctionDeclaration struct {
+    Name string `json:"name"`
+    Description string `json:"description"`
+    Parameters ToolParameters `json:"parameters"`
+}
+
+type ToolParameters struct {
+    Type string `json:"type"`
+    Properties map[string]ToolParameter `json:"properties,omitempty"`
+    Required []string `json:"required,omitempty"`
+}
+
+type ToolParameter struct {
+    Type string `json:"type"`
+    Description string `json:"description"`
+    Values []string `json:"values,omitempty"`
+}

@ -13,21 +13,26 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type ConversationStore interface {
|
type ConversationStore interface {
|
||||||
Conversations() ([]model.Conversation, error)
|
|
||||||
|
|
||||||
ConversationByShortName(shortName string) (*model.Conversation, error)
|
ConversationByShortName(shortName string) (*model.Conversation, error)
|
||||||
ConversationShortNameCompletions(search string) []string
|
ConversationShortNameCompletions(search string) []string
|
||||||
|
RootMessages(conversationID uint) ([]model.Message, error)
|
||||||
|
LatestConversationMessages() ([]model.Message, error)
|
||||||
|
|
||||||
SaveConversation(conversation *model.Conversation) error
|
StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error)
|
||||||
|
UpdateConversation(conversation *model.Conversation) error
|
||||||
DeleteConversation(conversation *model.Conversation) error
|
DeleteConversation(conversation *model.Conversation) error
|
||||||
|
CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error)
|
||||||
|
|
||||||
Messages(conversation *model.Conversation) ([]model.Message, error)
|
MessageByID(messageID uint) (*model.Message, error)
|
||||||
LastMessage(conversation *model.Conversation) (*model.Message, error)
|
MessageReplies(messageID uint) ([]model.Message, error)
|
||||||
|
|
||||||
SaveMessage(message *model.Message) error
|
|
||||||
DeleteMessage(message *model.Message) error
|
|
||||||
UpdateMessage(message *model.Message) error
|
UpdateMessage(message *model.Message) error
|
||||||
AddReply(conversation *model.Conversation, message model.Message) (*model.Message, error)
|
DeleteMessage(message *model.Message, prune bool) error
|
||||||
|
CloneBranch(toClone model.Message) (*model.Message, uint, error)
|
||||||
|
Reply(to *model.Message, messages ...model.Message) ([]model.Message, error)
|
||||||
|
|
||||||
|
PathToRoot(message *model.Message) ([]model.Message, error)
|
||||||
|
PathToLeaf(message *model.Message) ([]model.Message, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type SQLStore struct {
|
type SQLStore struct {
|
||||||
@ -52,47 +57,52 @@ func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
|
|||||||
return &SQLStore{db, _sqids}, nil
|
return &SQLStore{db, _sqids}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) SaveConversation(conversation *model.Conversation) error {
|
func (s *SQLStore) saveNewConversation(c *model.Conversation) error {
|
||||||
err := s.db.Save(&conversation).Error
|
// Save the new conversation
|
||||||
|
err := s.db.Save(&c).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !conversation.ShortName.Valid {
|
// Generate and save its "short name"
|
||||||
shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)})
|
shortName, _ := s.sqids.Encode([]uint64{uint64(c.ID)})
|
||||||
conversation.ShortName = sql.NullString{String: shortName, Valid: true}
|
c.ShortName = sql.NullString{String: shortName, Valid: true}
|
||||||
err = s.db.Save(&conversation).Error
|
return s.UpdateConversation(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLStore) UpdateConversation(c *model.Conversation) error {
|
||||||
|
if c == nil || c.ID == 0 {
|
||||||
|
return fmt.Errorf("Conversation is nil or invalid (missing ID)")
|
||||||
}
|
}
|
||||||
|
return s.db.Updates(&c).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLStore) DeleteConversation(c *model.Conversation) error {
|
||||||
|
// Delete messages first
|
||||||
|
err := s.db.Where("conversation_id = ?", c.ID).Delete(&model.Message{}).Error
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
|
return s.db.Delete(&c).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) DeleteConversation(conversation *model.Conversation) error {
|
func (s *SQLStore) DeleteMessage(message *model.Message, prune bool) error {
|
||||||
s.db.Where("conversation_id = ?", conversation.ID).Delete(&model.Message{})
|
panic("Not yet implemented")
|
||||||
return s.db.Delete(&conversation).Error
|
//return s.db.Delete(&message).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) SaveMessage(message *model.Message) error {
|
func (s *SQLStore) UpdateMessage(m *model.Message) error {
|
||||||
return s.db.Create(message).Error
|
if m == nil || m.ID == 0 {
|
||||||
}
|
return fmt.Errorf("Message is nil or invalid (missing ID)")
|
||||||
|
}
|
||||||
func (s *SQLStore) DeleteMessage(message *model.Message) error {
|
return s.db.Updates(&m).Error
|
||||||
return s.db.Delete(&message).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SQLStore) UpdateMessage(message *model.Message) error {
|
|
||||||
return s.db.Updates(&message).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SQLStore) Conversations() ([]model.Conversation, error) {
|
|
||||||
var conversations []model.Conversation
|
|
||||||
err := s.db.Find(&conversations).Error
|
|
||||||
return conversations, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
|
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
|
||||||
var completions []string
|
var conversations []model.Conversation
|
||||||
conversations, _ := s.Conversations() // ignore error for completions
|
// ignore error for completions
|
||||||
|
s.db.Find(&conversations)
|
||||||
|
completions := make([]string, 0, len(conversations))
|
||||||
for _, conversation := range conversations {
|
for _, conversation := range conversations {
|
||||||
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
|
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
|
||||||
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
|
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
|
||||||
@ -106,27 +116,249 @@ func (s *SQLStore) ConversationByShortName(shortName string) (*model.Conversatio
|
|||||||
return nil, errors.New("shortName is empty")
|
return nil, errors.New("shortName is empty")
|
||||||
}
|
}
|
||||||
var conversation model.Conversation
|
var conversation model.Conversation
|
||||||
err := s.db.Where("short_name = ?", shortName).Find(&conversation).Error
|
err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error
|
||||||
return &conversation, err
|
return &conversation, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) Messages(conversation *model.Conversation) ([]model.Message, error) {
|
func (s *SQLStore) RootMessages(conversationID uint) ([]model.Message, error) {
|
||||||
var messages []model.Message
|
var rootMessages []model.Message
|
||||||
err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error
|
err := s.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error
|
||||||
return messages, err
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return rootMessages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) LastMessage(conversation *model.Conversation) (*model.Message, error) {
|
func (s *SQLStore) MessageByID(messageID uint) (*model.Message, error) {
|
||||||
var message model.Message
|
var message model.Message
|
||||||
err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error
|
err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error
|
||||||
return &message, err
|
return &message, err
|
||||||
}
|
}
|
||||||
|
|
||||||
-// AddReply adds the given messages as a reply to the given conversation, can be
-// used to easily copy a message associated with one conversation, to another
-func (s *SQLStore) AddReply(c *model.Conversation, m model.Message) (*model.Message, error) {
-    m.ConversationID = c.ID
-    m.ID = 0
-    m.CreatedAt = time.Time{}
-    return &m, s.SaveMessage(&m)
+func (s *SQLStore) MessageReplies(messageID uint) ([]model.Message, error) {
+    var replies []model.Message
+    err := s.db.Where("parent_id = ?", messageID).Find(&replies).Error
+    return replies, err
+}
+
+// StartConversation starts a new conversation with the provided messages
+func (s *SQLStore) StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error) {
+    if len(messages) == 0 {
+        return nil, nil, fmt.Errorf("Must provide at least 1 message")
+    }
+
+    // Create new conversation
+    conversation := &model.Conversation{}
+    err := s.saveNewConversation(conversation)
+    if err != nil {
+        return nil, nil, err
+    }
+
+    // Create first message
+    messages[0].ConversationID = conversation.ID
+    err = s.db.Create(&messages[0]).Error
+    if err != nil {
+        return nil, nil, err
+    }
+
+    // Update conversation's selected root message
+    conversation.SelectedRoot = &messages[0]
+    err = s.UpdateConversation(conversation)
+    if err != nil {
+        return nil, nil, err
+    }
+
+    // Add additional replies to conversation
+    if len(messages) > 1 {
+        newMessages, err := s.Reply(&messages[0], messages[1:]...)
+        if err != nil {
+            return nil, nil, err
+        }
+        messages = append([]model.Message{messages[0]}, newMessages...)
+    }
+    return conversation, messages, nil
+}
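A short usage sketch of the API added above. This is not part of the diff; the helper name is hypothetical, it assumes placement alongside SQLStore, and model.MessageRoleAssistant is assumed to exist alongside the MessageRoleUser constant used elsewhere in the codebase:

// startExample starts a conversation: the first message becomes the
// conversation's SelectedRoot, and any further messages are chained as
// replies via Reply.
func startExample(store *SQLStore) (*model.Conversation, error) {
    conversation, _, err := store.StartConversation(
        model.Message{Role: model.MessageRoleUser, Content: "Hello!"},
        model.Message{Role: model.MessageRoleAssistant, Content: "Hi! How can I help?"}, // role constant assumed
    )
    if err != nil {
        return nil, err
    }
    return conversation, nil
}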
+
+// CloneConversation clones the given conversation and all of its root messages
+func (s *SQLStore) CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error) {
+    rootMessages, err := s.RootMessages(toClone.ID)
+    if err != nil {
+        return nil, 0, err
+    }
+
+    clone := &model.Conversation{
+        Title: toClone.Title + " - Clone",
+    }
+    if err := s.saveNewConversation(clone); err != nil {
+        return nil, 0, fmt.Errorf("Could not create clone: %s", err)
+    }
+
+    var errors []error
+    var messageCnt uint = 0
+    for _, root := range rootMessages {
+        messageCnt++
+        newRoot := root
+        newRoot.ConversationID = clone.ID
+
+        cloned, count, err := s.CloneBranch(newRoot)
+        if err != nil {
+            errors = append(errors, err)
+            continue
+        }
+        messageCnt += count
+
+        if root.ID == *toClone.SelectedRootID {
+            clone.SelectedRootID = &cloned.ID
+            if err := s.UpdateConversation(clone); err != nil {
+                errors = append(errors, fmt.Errorf("Could not set selected root on clone: %v", err))
+            }
+        }
+    }
+
+    if len(errors) > 0 {
+        return nil, 0, fmt.Errorf("Messages failed to be cloned: %v", errors)
+    }
+
+    return clone, messageCnt, nil
+}
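A usage sketch (hypothetical caller; the conversation value is assumed to have been loaded elsewhere, for example by short name):

// cloneExample clones an existing conversation, including every root message
// and its reply branches, and reports how many messages were copied.
func cloneExample(store *SQLStore, original model.Conversation) error {
    clone, copied, err := store.CloneConversation(original)
    if err != nil {
        return err
    }
    fmt.Printf("cloned %d messages into %q\n", copied, clone.Title)
    return nil
}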
+
+// Reply replies to the given parentMessage with a series of messages
+func (s *SQLStore) Reply(parentMessage *model.Message, messages ...model.Message) ([]model.Message, error) {
+    var savedMessages []model.Message
+    currentParent := parentMessage
+
+    err := s.db.Transaction(func(tx *gorm.DB) error {
+        for i := range messages {
+            message := messages[i]
+            message.ConversationID = currentParent.ConversationID
+            message.ParentID = &currentParent.ID
+            message.ID = 0
+            message.CreatedAt = time.Time{}
+
+            if err := tx.Create(&message).Error; err != nil {
+                return err
+            }
+
+            // update parent selected reply
+            currentParent.SelectedReply = &message
+            if err := tx.Model(currentParent).Update("selected_reply_id", message.ID).Error; err != nil {
+                return err
+            }
+
+            savedMessages = append(savedMessages, message)
+            currentParent = &message
+        }
+        return nil
+    })
+
+    return savedMessages, err
+}
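The selected-reply bookkeeping above is what enables branching: replying to the same parent twice leaves two sibling replies, with the newest one selected. A sketch of that behaviour (hypothetical helper, same-package assumption, and model.MessageRoleAssistant is an assumed constant):

// branchExample creates two alternative replies under the same parent; after
// the second call, the parent's selected_reply_id points at "draft B".
func branchExample(store *SQLStore, parent *model.Message) error {
    for _, content := range []string{"draft A", "draft B"} {
        _, err := store.Reply(parent, model.Message{
            Role:    model.MessageRoleAssistant, // assumed constant
            Content: content,
        })
        if err != nil {
            return err
        }
    }
    replies, err := store.MessageReplies(parent.ID)
    if err != nil {
        return err
    }
    fmt.Printf("parent now has %d replies\n", len(replies)) // prints 2
    return nil
}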
+
+// CloneBranch returns a deep clone of the given message and its replies, returning
+// a new message object. The new message will be attached to the same parent as
+// the message to clone.
+func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, uint, error) {
+    newMessage := messageToClone
+    newMessage.ID = 0
+    newMessage.Replies = nil
+    newMessage.SelectedReplyID = nil
+    newMessage.SelectedReply = nil
+
+    if err := s.db.Create(&newMessage).Error; err != nil {
+        return nil, 0, fmt.Errorf("Could not clone message: %s", err)
+    }
+
+    originalReplies, err := s.MessageReplies(messageToClone.ID)
+    if err != nil {
+        return nil, 0, fmt.Errorf("Could not fetch message %d replies: %v", messageToClone.ID, err)
+    }
+
+    var replyCount uint = 0
+    for _, reply := range originalReplies {
+        replyCount++
+
+        newReply := reply
+        newReply.ConversationID = messageToClone.ConversationID
+        newReply.ParentID = &newMessage.ID
+        newReply.Parent = &newMessage
+
+        res, c, err := s.CloneBranch(newReply)
+        if err != nil {
+            return nil, 0, err
+        }
+        newMessage.Replies = append(newMessage.Replies, *res)
+        replyCount += c
+
+        if reply.ID == *messageToClone.SelectedReplyID {
+            newMessage.SelectedReplyID = &res.ID
+            if err := s.UpdateMessage(&newMessage); err != nil {
+                return nil, 0, fmt.Errorf("Could not update parent select reply ID: %v", err)
+            }
+        }
+    }
+    return &newMessage, replyCount, nil
+}
+
+// PathToRoot traverses message Parent until reaching the tree root
+func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {
+    if message == nil {
+        return nil, fmt.Errorf("Message is nil")
+    }
+    var path []model.Message
+    current := message
+    for {
+        path = append([]model.Message{*current}, path...)
+        if current.Parent == nil {
+            break
+        }
+
+        var err error
+        current, err = s.MessageByID(*current.ParentID)
+        if err != nil {
+            return nil, fmt.Errorf("finding parent message: %w", err)
+        }
+    }
+    return path, nil
+}
+
+// PathToLeaf traverses message SelectedReply until reaching a tree leaf
+func (s *SQLStore) PathToLeaf(message *model.Message) ([]model.Message, error) {
+    if message == nil {
+        return nil, fmt.Errorf("Message is nil")
+    }
+    var path []model.Message
+    current := message
+    for {
+        path = append(path, *current)
+        if current.SelectedReplyID == nil {
+            break
+        }
+
+        var err error
+        current, err = s.MessageByID(*current.SelectedReplyID)
+        if err != nil {
+            return nil, fmt.Errorf("finding selected reply: %w", err)
+        }
+    }
+    return path, nil
+}
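PathToRoot and PathToLeaf are the two traversals the UI needs: one reconstructs the history above a message, the other follows the currently selected thread below it. A sketch of resolving a conversation's active thread, assuming the conversation was loaded with SelectedRoot preloaded (as ConversationByShortName now does); the helper itself is hypothetical:

// activeThread returns the linear sequence of messages currently shown for a
// conversation: the selected root followed by each SelectedReply down to the leaf.
func activeThread(store *SQLStore, conversation *model.Conversation) ([]model.Message, error) {
    if conversation.SelectedRoot == nil {
        return nil, fmt.Errorf("conversation has no selected root")
    }
    return store.PathToLeaf(conversation.SelectedRoot)
}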
+
+func (s *SQLStore) LatestConversationMessages() ([]model.Message, error) {
+    var latestMessages []model.Message
+
+    subQuery := s.db.Model(&model.Message{}).
+        Select("MAX(created_at) as max_created_at, conversation_id").
+        Group("conversation_id")
+
+    err := s.db.Model(&model.Message{}).
+        Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery).
+        Order("created_at DESC").
+        Preload("Conversation").
+        Find(&latestMessages).Error
+
+    if err != nil {
+        return nil, err
+    }
+
+    return latestMessages, nil
 }
@@ -11,7 +11,9 @@ import (
     toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
 )
 
-const TREE_DESCRIPTION = `Retrieve a tree view of a directory's contents.
+const TREE_DESCRIPTION = `Retrieve a tree-like view of a directory's contents.
 
+Use these results for your own reference in completing your task, they do not need to be shown to the user.
+
 Example result:
 {
@@ -35,48 +37,45 @@ var DirTreeTool = model.Tool{
             Description: "If set, display the tree starting from this path relative to the current one.",
         },
         {
-            Name: "max_depth",
+            Name: "depth",
             Type: "integer",
-            Description: "Maximum depth of recursion. Default is unlimited.",
+            Description: "Depth of directory recursion. Default 0. Use -1 for unlimited.",
         },
     },
     Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
         var relativeDir string
-        tmp, ok := args["relative_dir"]
-        if ok {
+        if tmp, ok := args["relative_path"]; ok {
             relativeDir, ok = tmp.(string)
             if !ok {
-                return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp)
+                return "", fmt.Errorf("expected string for relative_path, got %T", tmp)
             }
         }
 
-        var maxDepth int = -1
-        tmp, ok = args["max_depth"]
-        if ok {
-            maxDepth, ok = tmp.(int)
-            if !ok {
-                if tmps, ok := tmp.(string); ok {
-                    tmpi, err := strconv.Atoi(tmps)
-                    maxDepth = tmpi
-                    if err != nil {
-                        return "", fmt.Errorf("Invalid max_depth in function arguments: %v", tmp)
-                    }
-                } else {
-                    return "", fmt.Errorf("Invalid max_depth in function arguments: %v", tmp)
-                }
-            }
-        }
+        var depth int = 0 // Default value if not provided
+        if tmp, ok := args["depth"]; ok {
+            switch v := tmp.(type) {
+            case float64:
+                depth = int(v)
+            case string:
+                var err error
+                if depth, err = strconv.Atoi(v); err != nil {
+                    return "", fmt.Errorf("invalid `depth` value, expected integer but got string that cannot convert: %v", tmp)
+                }
+            default:
+                return "", fmt.Errorf("expected int or string for max_depth, got %T", tmp)
+            }
+        }
 
-        result := tree(relativeDir, maxDepth)
+        result := tree(relativeDir, depth)
         ret, err := result.ToJson()
         if err != nil {
-            return "", fmt.Errorf("Could not serialize result: %v", err)
+            return "", fmt.Errorf("could not serialize result: %v", err)
         }
         return ret, nil
     },
 }
 
-func tree(path string, maxDepth int) model.CallResult {
+func tree(path string, depth int) model.CallResult {
     if path == "" {
         path = "."
     }
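The float64 case exists because tool-call arguments arrive as JSON, and encoding/json unmarshals JSON numbers into float64 when the destination is a map[string]interface{}. A standalone illustration (not part of the diff):

package main

import (
    "encoding/json"
    "fmt"
)

func main() {
    var args map[string]interface{}
    // JSON numbers decode to float64 when the target type is interface{}
    _ = json.Unmarshal([]byte(`{"relative_path": "pkg", "depth": 2}`), &args)
    fmt.Printf("%T %v\n", args["depth"], args["depth"]) // prints: float64 2
}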
@@ -87,7 +86,7 @@ func tree(path string, maxDepth int) model.CallResult {
 
     var treeOutput strings.Builder
     treeOutput.WriteString(path + "\n")
-    err := buildTree(&treeOutput, path, "", maxDepth)
+    err := buildTree(&treeOutput, path, "", depth)
     if err != nil {
         return model.CallResult{
             Message: err.Error(),
@@ -97,7 +96,7 @@ func tree(path string, maxDepth int) model.CallResult {
     return model.CallResult{Result: treeOutput.String()}
 }
 
-func buildTree(output *strings.Builder, path string, prefix string, maxDepth int) error {
+func buildTree(output *strings.Builder, path string, prefix string, depth int) error {
     files, err := os.ReadDir(path)
     if err != nil {
         return err
@@ -124,14 +123,14 @@ func buildTree(output *strings.Builder, path string, prefix string, maxDepth int
         output.WriteString(prefix + branch + file.Name())
         if file.IsDir() {
             output.WriteString("/\n")
-            if maxDepth != 0 {
+            if depth != 0 {
                 var nextPrefix string
                 if isLast {
                     nextPrefix = prefix + "    "
                 } else {
                     nextPrefix = prefix + "│   "
                 }
-                buildTree(output, filepath.Join(path, file.Name()), nextPrefix, maxDepth-1)
+                buildTree(output, filepath.Join(path, file.Name()), nextPrefix, depth-1)
             }
         } else {
             output.WriteString(sizeStr + "\n")
@@ -140,4 +139,3 @@ func buildTree(output *strings.Builder, path string, prefix string, maxDepth int
 
     return nil
 }
-
@@ -9,7 +9,9 @@ import (
     toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
 )
 
-const READ_FILE_DESCRIPTION = `Read the contents of a text file relative to the current working directory.
+const READ_FILE_DESCRIPTION = `Retrieve the contents of a text file relative to the current working directory.
 
+Use the file contents for your own reference in completing your task, they do not need to be shown to the user.
+
 Each line of the returned content is prefixed with its line number and a tab (\t).
 
140  pkg/tui/chat.go
@@ -66,6 +66,10 @@ type chatModel struct {
     replyChunkChan chan string
     persistence    bool // whether we will save new messages in the conversation
 
+    tokenCount uint
+    startTime  time.Time
+    elapsed    time.Duration
+
     // ui state
     focus focusState
     wrap  bool // whether message content is wrapped to viewport width
@@ -282,6 +286,9 @@ func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
         }
         m.updateContent()
         cmds = append(cmds, m.waitForChunk()) // wait for the next chunk
+
+        m.tokenCount++
+        m.elapsed = time.Now().Sub(m.startTime)
     case msgAssistantReply:
         // the last reply that was being worked on is finished
         reply := models.Message(msg)
@@ -300,14 +307,9 @@ func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
         }
 
         if m.persistence {
-            var err error
-            if m.conversation.ID == 0 {
-                err = m.ctx.Store.SaveConversation(m.conversation)
-            }
+            err := m.persistConversation()
             if err != nil {
                 cmds = append(cmds, wrapError(err))
-            } else {
-                cmds = append(cmds, m.persistConversation())
             }
         }
 
@@ -334,7 +336,7 @@
         title := string(msg)
         m.conversation.Title = title
         if m.persistence {
-            err := m.ctx.Store.SaveConversation(m.conversation)
+            err := m.ctx.Store.UpdateConversation(m.conversation)
             if err != nil {
                 cmds = append(cmds, wrapError(err))
             }
@@ -462,8 +464,8 @@ func (m *chatModel) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
         m.input.Blur()
         return true, nil
     case "ctrl+s":
-        userInput := strings.TrimSpace(m.input.Value())
-        if strings.TrimSpace(userInput) == "" {
+        input := strings.TrimSpace(m.input.Value())
+        if input == "" {
             return true, nil
         }
 
@@ -471,35 +473,19 @@ func (m *chatModel) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
             return true, wrapError(fmt.Errorf("Can't reply to a user message"))
         }
 
-        reply := models.Message{
-            Role:    models.MessageRoleUser,
-            Content: userInput,
-        }
-
-        if m.persistence {
-            var err error
-            if m.conversation.ID == 0 {
-                err = m.ctx.Store.SaveConversation(m.conversation)
-            }
-            if err != nil {
-                return true, wrapError(err)
-            }
-
-            // ensure all messages up to the one we're about to add are persisted
-            cmd := m.persistConversation()
-            if cmd != nil {
-                return true, cmd
-            }
-
-            savedReply, err := m.ctx.Store.AddReply(m.conversation, reply)
-            if err != nil {
-                return true, wrapError(err)
-            }
-            reply = *savedReply
-        }
+        m.addMessage(models.Message{
+            Role:    models.MessageRoleUser,
+            Content: input,
+        })
 
         m.input.SetValue("")
-        m.addMessage(reply)
+
+        if m.persistence {
+            err := m.persistConversation()
+            if err != nil {
+                return true, wrapError(err)
+            }
+        }
 
         m.updateContent()
         m.content.GotoBottom()
@@ -693,10 +679,16 @@ func (m *chatModel) footerView() string {
         saving,
         segmentStyle.Render(status),
     }
-    rightSegments := []string{
-        segmentStyle.Render(fmt.Sprintf("Model: %s", *m.ctx.Config.Defaults.Model)),
+    rightSegments := []string{}
+
+    if m.elapsed > 0 && m.tokenCount > 0 {
+        throughput := fmt.Sprintf("%.0f t/sec", float64(m.tokenCount)/m.elapsed.Seconds())
+        rightSegments = append(rightSegments, segmentStyle.Render(throughput))
     }
 
+    model := fmt.Sprintf("Model: %s", *m.ctx.Config.Defaults.Model)
+    rightSegments = append(rightSegments, segmentStyle.Render(model))
+
     left := strings.Join(leftSegments, segmentSeparator)
     right := strings.Join(rightSegments, segmentSeparator)
 
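The new throughput segment is simply tokens over elapsed wall-clock time. A standalone sketch of the arithmetic, with illustrative values (not part of the diff):

package main

import (
    "fmt"
    "time"
)

func main() {
    tokenCount := uint(120)
    elapsed := 2500 * time.Millisecond
    // mirrors the footer segment: whole tokens per second
    throughput := fmt.Sprintf("%.0f t/sec", float64(tokenCount)/elapsed.Seconds())
    fmt.Println(throughput) // prints: 48 t/sec
}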
@@ -770,7 +762,7 @@ func (m *chatModel) loadConversation(shortname string) tea.Cmd {
 
 func (m *chatModel) loadMessages(c *models.Conversation) tea.Cmd {
     return func() tea.Msg {
-        messages, err := m.ctx.Store.Messages(c)
+        messages, err := m.ctx.Store.PathToLeaf(c.SelectedRoot)
         if err != nil {
             return msgError(fmt.Errorf("Could not load conversation messages: %v\n", err))
         }
@@ -778,62 +770,48 @@ func (m *chatModel) loadMessages(c *models.Conversation) tea.Cmd {
     }
 }
 
-func (m *chatModel) persistConversation() tea.Cmd {
-    existingMessages, err := m.ctx.Store.Messages(m.conversation)
-    if err != nil {
-        return wrapError(fmt.Errorf("Could not retrieve existing conversation messages while trying to save: %v", err))
-    }
-
-    existingById := make(map[uint]*models.Message, len(existingMessages))
-    for _, msg := range existingMessages {
-        existingById[msg.ID] = &msg
-    }
-
-    currentById := make(map[uint]*models.Message, len(m.messages))
-    for _, msg := range m.messages {
-        currentById[msg.ID] = &msg
-    }
-
-    for _, msg := range existingMessages {
-        _, ok := currentById[msg.ID]
-        if !ok {
-            err := m.ctx.Store.DeleteMessage(&msg)
-            if err != nil {
-                return wrapError(fmt.Errorf("Failed to remove messages: %v", err))
-            }
-        }
-    }
-
-    for i, msg := range m.messages {
-        if msg.ID > 0 {
-            exist, ok := existingById[msg.ID]
-            if ok {
-                if msg.Content == exist.Content {
-                    continue
-                }
-                // update message when contents don't match that of store
-                err := m.ctx.Store.UpdateMessage(&msg)
-                if err != nil {
-                    return wrapError(err)
-                }
-            } else {
-                // this would be quite odd... and I'm not sure how to handle
-                // it at the time of writing this
-            }
-        } else {
-            newMessage, err := m.ctx.Store.AddReply(m.conversation, msg)
-            if err != nil {
-                return wrapError(err)
-            }
-            m.setMessage(i, *newMessage)
-        }
-    }
+func (m *chatModel) persistConversation() error {
+    if m.conversation.ID == 0 {
+        // Start a new conversation with all messages so far
+        c, messages, err := m.ctx.Store.StartConversation(m.messages...)
+        if err != nil {
+            return err
+        }
+        m.conversation = c
+        m.messages = messages
+
+        return nil
+    }
+
+    // else, we'll handle updating an existing conversation's messages
+    for i := 0; i < len(m.messages); i++ {
+        if m.messages[i].ID > 0 {
+            // message has an ID, update its contents
+            // TODO: check for content/tool equality before updating?
+            err := m.ctx.Store.UpdateMessage(&m.messages[i])
+            if err != nil {
+                return err
+            }
+        } else if i > 0 {
+            // messages is new, so add it as a reply to previous message
+            saved, err := m.ctx.Store.Reply(&m.messages[i-1], m.messages[i])
+            if err != nil {
+                return err
+            }
+            m.messages[i] = saved[0]
+        } else {
+            // message has no id and no previous messages to add it to
+            // this shouldn't happen?
+            return fmt.Errorf("Error: no messages to reply to")
+        }
+    }
 
     return nil
 }
 
 func (m *chatModel) generateConversationTitle() tea.Cmd {
     return func() tea.Msg {
-        title, err := cmdutil.GenerateTitle(m.ctx, m.conversation)
+        title, err := cmdutil.GenerateTitle(m.ctx, m.messages)
         if err != nil {
             return msgError(err)
         }
@@ -857,6 +835,10 @@ func (m *chatModel) promptLLM() tea.Cmd {
     m.waitingForReply = true
     m.status = "Press ctrl+c to cancel"
 
+    m.tokenCount = 0
+    m.startTime = time.Now()
+    m.elapsed = 0
+
     return func() tea.Msg {
         completionProvider, err := m.ctx.GetCompletionProvider(*m.ctx.Config.Defaults.Model)
         if err != nil {
@@ -2,7 +2,6 @@ package tui
 
 import (
     "fmt"
-    "slices"
     "strings"
     "time"
 
@@ -115,7 +114,7 @@ func (m *conversationsModel) handleResize(width, height int) {
 func (m conversationsModel) Update(msg tea.Msg) (conversationsModel, tea.Cmd) {
     var cmds []tea.Cmd
     switch msg := msg.(type) {
-    case msgStateChange:
+    case msgStateEnter:
         cmds = append(cmds, m.loadConversations())
         m.content.SetContent(m.renderConversationList())
     case tea.WindowSizeMsg:
@@ -145,24 +144,16 @@ func (m conversationsModel) Update(msg tea.Msg) (conversationsModel, tea.Cmd) {
 
 func (m *conversationsModel) loadConversations() tea.Cmd {
     return func() tea.Msg {
-        conversations, err := m.ctx.Store.Conversations()
+        messages, err := m.ctx.Store.LatestConversationMessages()
         if err != nil {
             return msgError(fmt.Errorf("Could not load conversations: %v", err))
         }
 
-        loaded := make([]loadedConversation, len(conversations))
-        for i, c := range conversations {
-            lastMessage, err := m.ctx.Store.LastMessage(&c)
-            if err != nil {
-                return msgError(err)
-            }
-            loaded[i].conv = c
-            loaded[i].lastReply = *lastMessage
+        loaded := make([]loadedConversation, len(messages))
+        for i, m := range messages {
+            loaded[i].lastReply = m
+            loaded[i].conv = m.Conversation
         }
 
-        slices.SortFunc(loaded, func(a, b loadedConversation) int {
-            return b.lastReply.CreatedAt.Compare(a.lastReply.CreatedAt)
-        })
-
         return msgConversationsLoaded(loaded)
     }
@@ -147,16 +147,16 @@ func SetStructDefaults(data interface{}) bool {
         case reflect.String:
             defaultValue := defaultTag
             field.Set(reflect.ValueOf(&defaultValue))
+        case reflect.Uint, reflect.Uint32, reflect.Uint64:
+            intValue, _ := strconv.ParseUint(defaultTag, 10, e.Bits())
+            field.Set(reflect.New(e))
+            field.Elem().SetUint(intValue)
         case reflect.Int, reflect.Int32, reflect.Int64:
-            intValue, _ := strconv.ParseInt(defaultTag, 10, 64)
+            intValue, _ := strconv.ParseInt(defaultTag, 10, e.Bits())
             field.Set(reflect.New(e))
             field.Elem().SetInt(intValue)
-        case reflect.Float32:
-            floatValue, _ := strconv.ParseFloat(defaultTag, 32)
-            field.Set(reflect.New(e))
-            field.Elem().SetFloat(floatValue)
-        case reflect.Float64:
-            floatValue, _ := strconv.ParseFloat(defaultTag, 64)
+        case reflect.Float32, reflect.Float64:
+            floatValue, _ := strconv.ParseFloat(defaultTag, e.Bits())
             field.Set(reflect.New(e))
             field.Elem().SetFloat(floatValue)
         case reflect.Bool:
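For illustration, a hypothetical config struct exercising the new unsigned and float handling. The "default" struct tag key and same-package placement are assumptions, as neither appears in this hunk:

// Sketch only: pointer fields left nil get their tagged default filled in.
type exampleDefaults struct {
    MaxTokens   *uint    `default:"256"`
    Temperature *float64 `default:"0.7"`
}

func example() {
    cfg := &exampleDefaults{}
    SetStructDefaults(cfg) // assumed callable from within the same package
    // cfg.MaxTokens   -> 256
    // cfg.Temperature -> 0.7
}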