Compare commits

..

No commits in common. "db465f1bf05c747e065ad84ee7daa37547af442c" and "aeeb7bb7f79ccc5a1380b47a54e4b33559381101" have entirely different histories.

24 changed files with 348 additions and 1061 deletions

View File

@ -3,7 +3,6 @@ package cmd
import ( import (
"fmt" "fmt"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/tui" "git.mlow.ca/mlow/lmcli/pkg/tui"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -15,16 +14,11 @@ func ChatCmd(ctx *lmcli.Context) *cobra.Command {
Short: "Open the chat interface", Short: "Open the chat interface",
Long: `Open the chat interface, optionally on a given conversation.`, Long: `Open the chat interface, optionally on a given conversation.`,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
// TODO: implement jump-to-conversation logic
shortname := "" shortname := ""
if len(args) == 1 { if len(args) == 1 {
shortname = args[0] shortname = args[0]
} }
if shortname != ""{
_, err := cmdutil.LookupConversationE(ctx, shortname)
if err != nil {
return err
}
}
err := tui.Launch(ctx, shortname) err := tui.Launch(ctx, shortname)
if err != nil { if err != nil {
return fmt.Errorf("Error fetching LLM response: %v", err) return fmt.Errorf("Error fetching LLM response: %v", err)

View File

@ -5,6 +5,7 @@ import (
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -27,12 +28,36 @@ func CloneCmd(ctx *lmcli.Context) *cobra.Command {
return err return err
} }
clone, messageCnt, err := ctx.Store.CloneConversation(*toClone) messagesToCopy, err := ctx.Store.Messages(toClone)
if err != nil { if err != nil {
return fmt.Errorf("Failed to clone conversation: %v", err) return fmt.Errorf("Could not retrieve messages for conversation: %s", toClone.ShortName.String)
} }
fmt.Printf("Cloned %d messages to: %s - %s\n", messageCnt, clone.ShortName.String, clone.Title) clone := &model.Conversation{
Title: toClone.Title + " - Clone",
}
if err := ctx.Store.SaveConversation(clone); err != nil {
return fmt.Errorf("Cloud not create clone: %s", err)
}
var errors []error
messageCnt := 0
for _, message := range messagesToCopy {
newMessage := message
newMessage.ConversationID = clone.ID
newMessage.ID = 0
if err := ctx.Store.SaveMessage(&newMessage); err != nil {
errors = append(errors, err)
} else {
messageCnt++
}
}
if len(errors) > 0 {
return fmt.Errorf("Messages failed to be cloned: %v", errors)
}
fmt.Printf("Cloned %d messages to: %s\n", messageCnt, clone.Title)
return nil return nil
}, },
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {

View File

@ -26,7 +26,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
shortName := args[0] shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot) messages, err := ctx.Store.Messages(conversation)
if err != nil { if err != nil {
return fmt.Errorf("could not retrieve conversation messages: %v", err) return fmt.Errorf("could not retrieve conversation messages: %v", err)
} }

View File

@ -24,7 +24,7 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
shortName := args[0] shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot) messages, err := ctx.Store.Messages(conversation)
if err != nil { if err != nil {
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title) return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
} }
@ -39,7 +39,21 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
} }
desiredIdx := len(messages) - 1 - offset desiredIdx := len(messages) - 1 - offset
toEdit := messages[desiredIdx]
// walk backwards through the conversation deleting messages until and
// including the last user message
toRemove := []model.Message{}
var toEdit *model.Message
for i := len(messages) - 1; i >= 0; i-- {
if i == desiredIdx {
toEdit = &messages[i]
}
toRemove = append(toRemove, messages[i])
messages = messages[:i]
if toEdit != nil {
break
}
}
newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", toEdit.Content) newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", toEdit.Content)
switch newContents { switch newContents {
@ -49,17 +63,26 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
return fmt.Errorf("No message was provided.") return fmt.Errorf("No message was provided.")
} }
toEdit.Content = newContents
role, _ := cmd.Flags().GetString("role") role, _ := cmd.Flags().GetString("role")
if role != "" { if role == "" {
if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) { role = string(toEdit.Role)
} else if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) {
return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.") return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.")
} }
toEdit.Role = model.MessageRole(role)
for _, message := range toRemove {
err = ctx.Store.DeleteMessage(&message)
if err != nil {
lmcli.Warn("Could not delete message: %v\n", err)
}
} }
return ctx.Store.UpdateMessage(&toEdit) cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{
ConversationID: conversation.ID,
Role: model.MessageRole(role),
Content: newContents,
})
return nil
}, },
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp compMode := cobra.ShellCompDirectiveNoFileComp

View File

@ -2,6 +2,7 @@ package cmd
import ( import (
"fmt" "fmt"
"slices"
"time" "time"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
@ -20,7 +21,7 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
Short: "List conversations", Short: "List conversations",
Long: `List conversations in order of recent activity`, Long: `List conversations in order of recent activity`,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
messages, err := ctx.Store.LatestConversationMessages() conversations, err := ctx.Store.Conversations()
if err != nil { if err != nil {
return fmt.Errorf("Could not fetch conversations: %v", err) return fmt.Errorf("Could not fetch conversations: %v", err)
} }
@ -57,8 +58,13 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
all, _ := cmd.Flags().GetBool("all") all, _ := cmd.Flags().GetBool("all")
for _, message := range messages { for _, conversation := range conversations {
messageAge := now.Sub(message.CreatedAt) lastMessage, err := ctx.Store.LastMessage(&conversation)
if lastMessage == nil || err != nil {
continue
}
messageAge := now.Sub(lastMessage.CreatedAt)
var category string var category string
for _, c := range categories { for _, c := range categories {
@ -70,9 +76,9 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
formatted := fmt.Sprintf( formatted := fmt.Sprintf(
"%s - %s - %s", "%s - %s - %s",
message.Conversation.ShortName.String, conversation.ShortName.String,
util.HumanTimeElapsedSince(messageAge), util.HumanTimeElapsedSince(messageAge),
message.Conversation.Title, conversation.Title,
) )
categorized[category] = append( categorized[category] = append(
@ -90,10 +96,14 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
continue continue
} }
slices.SortFunc(conversationLines, func(a, b ConversationLine) int {
return int(a.timeSinceReply - b.timeSinceReply)
})
fmt.Printf("%s:\n", category.name) fmt.Printf("%s:\n", category.name)
for _, conv := range conversationLines { for _, conv := range conversationLines {
if conversationsPrinted >= count && !all { if conversationsPrinted >= count && !all {
fmt.Printf("%d remaining conversation(s), use --all to view.\n", len(messages)-conversationsPrinted) fmt.Printf("%d remaining message(s), use --all to view.\n", len(conversations)-conversationsPrinted)
break outer break outer
} }

View File

@ -15,43 +15,42 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
Short: "Start a new conversation", Short: "Start a new conversation",
Long: `Start a new conversation with the Large Language Model.`, Long: `Start a new conversation with the Large Language Model.`,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
input := inputFromArgsOrEditor(args, "# Start a new conversation below\n", "") messageContents := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
if input == "" { if messageContents == "" {
return fmt.Errorf("No message was provided.") return fmt.Errorf("No message was provided.")
} }
var messages []model.Message conversation := &model.Conversation{}
err := ctx.Store.SaveConversation(conversation)
if err != nil {
return fmt.Errorf("Could not save new conversation: %v", err)
}
// TODO: probably just make this part of the conversation messages := []model.Message{
system := ctx.GetSystemPrompt() {
if system != "" { ConversationID: conversation.ID,
messages = append(messages, model.Message{
Role: model.MessageRoleSystem, Role: model.MessageRoleSystem,
Content: system, Content: ctx.GetSystemPrompt(),
}) },
} {
ConversationID: conversation.ID,
messages = append(messages, model.Message{
Role: model.MessageRoleUser, Role: model.MessageRoleUser,
Content: input, Content: messageContents,
}) },
conversation, messages, err := ctx.Store.StartConversation(messages...)
if err != nil {
return fmt.Errorf("Could not start a new conversation: %v", err)
} }
cmdutil.HandleReply(ctx, &messages[len(messages)-1], true) cmdutil.HandleConversationReply(ctx, conversation, true, messages...)
title, err := cmdutil.GenerateTitle(ctx, messages) title, err := cmdutil.GenerateTitle(ctx, conversation)
if err != nil { if err != nil {
lmcli.Warn("Could not generate title for conversation %s: %v\n", conversation.ShortName.String, err) lmcli.Warn("Could not generate title for conversation: %v\n", err)
} }
conversation.Title = title conversation.Title = title
err = ctx.Store.UpdateConversation(conversation)
err = ctx.Store.SaveConversation(conversation)
if err != nil { if err != nil {
lmcli.Warn("Could not save conversation title: %v\n", err) lmcli.Warn("Could not save conversation after generating title: %v\n", err)
} }
return nil return nil
}, },

View File

@ -15,26 +15,21 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
Short: "Do a one-shot prompt", Short: "Do a one-shot prompt",
Long: `Prompt the Large Language Model and get a response.`, Long: `Prompt the Large Language Model and get a response.`,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
input := inputFromArgsOrEditor(args, "# Write your prompt below\n", "") message := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
if input == "" { if message == "" {
return fmt.Errorf("No message was provided.") return fmt.Errorf("No message was provided.")
} }
var messages []model.Message messages := []model.Message{
{
// TODO: stop supplying system prompt as a message
system := ctx.GetSystemPrompt()
if system != "" {
messages = append(messages, model.Message{
Role: model.MessageRoleSystem, Role: model.MessageRoleSystem,
Content: system, Content: ctx.GetSystemPrompt(),
}) },
} {
messages = append(messages, model.Message{
Role: model.MessageRoleUser, Role: model.MessageRoleUser,
Content: input, Content: message,
}) },
}
_, err := cmdutil.Prompt(ctx, messages, nil) _, err := cmdutil.Prompt(ctx, messages, nil)
if err != nil { if err != nil {

View File

@ -24,17 +24,12 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
shortName := args[0] shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
var err error var err error
var title string
generate, _ := cmd.Flags().GetBool("generate") generate, _ := cmd.Flags().GetBool("generate")
var title string
if generate { if generate {
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot) title, err = cmdutil.GenerateTitle(ctx, conversation)
if err != nil {
return fmt.Errorf("Could not retrieve conversation messages: %v", err)
}
title, err = cmdutil.GenerateTitle(ctx, messages)
if err != nil { if err != nil {
return fmt.Errorf("Could not generate conversation title: %v", err) return fmt.Errorf("Could not generate conversation title: %v", err)
} }
@ -46,9 +41,9 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
} }
conversation.Title = title conversation.Title = title
err = ctx.Store.UpdateConversation(conversation) err = ctx.Store.SaveConversation(conversation)
if err != nil { if err != nil {
lmcli.Warn("Could not update conversation title: %v\n", err) lmcli.Warn("Could not save conversation with new title: %v\n", err)
} }
return nil return nil
}, },

View File

@ -13,7 +13,7 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "retry <conversation>", Use: "retry <conversation>",
Short: "Retry the last user reply in a conversation", Short: "Retry the last user reply in a conversation",
Long: `Prompt the conversation from the last user response.`, Long: `Re-prompt the conversation up to the last user response. Can be used to regenerate the last assistant reply, or simply generate one if an error occurred.`,
Args: func(cmd *cobra.Command, args []string) error { Args: func(cmd *cobra.Command, args []string) error {
argCount := 1 argCount := 1
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
@ -25,28 +25,25 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
shortName := args[0] shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
// Load the complete thread from the root message messages, err := ctx.Store.Messages(conversation)
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
if err != nil { if err != nil {
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title) return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
} }
// Find the last user message in the conversation // walk backwards through the conversation and delete messages, break
var lastUserMessage *model.Message // when we find the latest user response
var i int for i := len(messages) - 1; i >= 0; i-- {
for i = len(messages) - 1; i >= 0; i-- {
if messages[i].Role == model.MessageRoleUser { if messages[i].Role == model.MessageRoleUser {
lastUserMessage = &messages[i]
break break
} }
err = ctx.Store.DeleteMessage(&messages[i])
if err != nil {
lmcli.Warn("Could not delete previous reply: %v\n", err)
}
} }
if lastUserMessage == nil { cmdutil.HandleConversationReply(ctx, conversation, true)
return fmt.Errorf("No user message found in the conversation: %s", conversation.Title)
}
// Start a new branch at the last user message
cmdutil.HandleReply(ctx, lastUserMessage, true)
return nil return nil
}, },
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {

View File

@ -57,7 +57,7 @@ func LookupConversation(ctx *lmcli.Context, shortName string) *model.Conversatio
lmcli.Fatal("Could not lookup conversation: %v\n", err) lmcli.Fatal("Could not lookup conversation: %v\n", err)
} }
if c.ID == 0 { if c.ID == 0 {
lmcli.Fatal("Conversation not found: %s\n", shortName) lmcli.Fatal("Conversation not found with short name: %s\n", shortName)
} }
return c return c
} }
@ -68,63 +68,48 @@ func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversat
return nil, fmt.Errorf("Could not lookup conversation: %v", err) return nil, fmt.Errorf("Could not lookup conversation: %v", err)
} }
if c.ID == 0 { if c.ID == 0 {
return nil, fmt.Errorf("Conversation not found: %s", shortName) return nil, fmt.Errorf("Conversation not found with short name: %s", shortName)
} }
return c, nil return c, nil
} }
func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
messages, err := ctx.Store.PathToLeaf(c.SelectedRoot)
if err != nil {
lmcli.Fatal("Could not load messages: %v\n", err)
}
HandleReply(ctx, &messages[len(messages)-1], persist, toSend...)
}
// handleConversationReply handles sending messages to an existing // handleConversationReply handles sending messages to an existing
// conversation, optionally persisting both the sent replies and responses. // conversation, optionally persisting both the sent replies and responses.
func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages ...model.Message) { func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
if to == nil { existing, err := ctx.Store.Messages(c)
lmcli.Fatal("Can't prompt from an empty message.")
}
existing, err := ctx.Store.PathToRoot(to)
if err != nil { if err != nil {
lmcli.Fatal("Could not load messages: %v\n", err) lmcli.Fatal("Could not retrieve messages for conversation: %s\n", c.Title)
} }
RenderConversation(ctx, append(existing, messages...), true) if persist {
for _, message := range toSend {
var savedReplies []model.Message err = ctx.Store.SaveMessage(&message)
if persist && len(messages) > 0 {
savedReplies, err = ctx.Store.Reply(to, messages...)
if err != nil { if err != nil {
lmcli.Warn("Could not save messages: %v\n", err) lmcli.Warn("Could not save %s message: %v\n", message.Role, err)
} }
} }
}
allMessages := append(existing, toSend...)
RenderConversation(ctx, allMessages, true)
// render a message header with no contents // render a message header with no contents
RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant})) RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
var lastMessage *model.Message
lastMessage = to
if len(savedReplies) > 1 {
lastMessage = &savedReplies[len(savedReplies)-1]
}
replyCallback := func(reply model.Message) { replyCallback := func(reply model.Message) {
if !persist { if !persist {
return return
} }
savedReplies, err = ctx.Store.Reply(lastMessage, reply)
reply.ConversationID = c.ID
err = ctx.Store.SaveMessage(&reply)
if err != nil { if err != nil {
lmcli.Warn("Could not save reply: %v\n", err) lmcli.Warn("Could not save reply: %v\n", err)
} }
lastMessage = &savedReplies[0]
} }
_, err = Prompt(ctx, append(existing, messages...), replyCallback) _, err = Prompt(ctx, allMessages, replyCallback)
if err != nil { if err != nil {
lmcli.Fatal("Error fetching LLM response: %v\n", err) lmcli.Fatal("Error fetching LLM response: %v\n", err)
} }
@ -149,7 +134,12 @@ func FormatForExternalPrompt(messages []model.Message, system bool) string {
return sb.String() return sb.String()
} }
func GenerateTitle(ctx *lmcli.Context, messages []model.Message) (string, error) { func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) {
messages, err := ctx.Store.Messages(c)
if err != nil {
return "", err
}
const prompt = `Above is an excerpt from a conversation between a user and AI assistant. Please reply with a short title (no more than 8 words) that reflects the topic of the conversation, read from the user's perspective. const prompt = `Above is an excerpt from a conversation between a user and AI assistant. Please reply with a short title (no more than 8 words) that reflects the topic of the conversation, read from the user's perspective.
Example conversation: Example conversation:
@ -196,7 +186,6 @@ Title: A brief introduction
response = strings.TrimPrefix(response, "Title: ") response = strings.TrimPrefix(response, "Title: ")
response = strings.Trim(response, "\"") response = strings.Trim(response, "\"")
response = strings.TrimSpace(response)
return response, nil return response, nil
} }

View File

@ -24,9 +24,9 @@ func ViewCmd(ctx *lmcli.Context) *cobra.Command {
shortName := args[0] shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot) messages, err := ctx.Store.Messages(conversation)
if err != nil { if err != nil {
return fmt.Errorf("Could not retrieve messages for conversation %s: %v", conversation.ShortName.String, err) return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
} }
cmdutil.RenderConversation(ctx, messages, false) cmdutil.RenderConversation(ctx, messages, false)

View File

@ -22,7 +22,6 @@ type Config struct {
EnabledTools []string `yaml:"enabledTools"` EnabledTools []string `yaml:"enabledTools"`
} `yaml:"tools"` } `yaml:"tools"`
Providers []*struct { Providers []*struct {
Name *string `yaml:"name"`
Kind *string `yaml:"kind"` Kind *string `yaml:"kind"`
BaseURL *string `yaml:"baseUrl"` BaseURL *string `yaml:"baseUrl"`
APIKey *string `yaml:"apiKey"` APIKey *string `yaml:"apiKey"`

View File

@ -4,12 +4,10 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider" "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic" "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/google"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai" "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
"git.mlow.ca/mlow/lmcli/pkg/util" "git.mlow.ca/mlow/lmcli/pkg/util"
@ -36,9 +34,7 @@ func NewContext() (*Context, error) {
} }
databaseFile := filepath.Join(dataDir(), "conversations.db") databaseFile := filepath.Join(dataDir(), "conversations.db")
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{ db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{})
//Logger: logger.Default.LogMode(logger.Info),
})
if err != nil { if err != nil {
return nil, fmt.Errorf("Error establishing connection to store: %v", err) return nil, fmt.Errorf("Error establishing connection to store: %v", err)
} }
@ -61,37 +57,16 @@ func NewContext() (*Context, error) {
} }
func (c *Context) GetModels() (models []string) { func (c *Context) GetModels() (models []string) {
modelCounts := make(map[string]int)
for _, p := range c.Config.Providers { for _, p := range c.Config.Providers {
for _, m := range *p.Models { for _, m := range *p.Models {
modelCounts[m]++
models = append(models, *p.Name+"/"+m)
}
}
for m, c := range modelCounts {
if c == 1 {
models = append(models, m) models = append(models, m)
} }
} }
return return
} }
func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) { func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) {
parts := strings.Split(model, "/")
var provider string
if len(parts) > 1 {
provider = parts[0]
model = parts[1]
}
for _, p := range c.Config.Providers { for _, p := range c.Config.Providers {
if provider != "" && *p.Name != provider {
continue
}
for _, m := range *p.Models { for _, m := range *p.Models {
if m == model { if m == model {
switch *p.Kind { switch *p.Kind {
@ -100,28 +75,21 @@ func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionCl
if p.BaseURL != nil { if p.BaseURL != nil {
url = *p.BaseURL url = *p.BaseURL
} }
return &anthropic.AnthropicClient{ anthropic := &anthropic.AnthropicClient{
BaseURL: url, BaseURL: url,
APIKey: *p.APIKey, APIKey: *p.APIKey,
}, nil
case "google":
url := "https://generativelanguage.googleapis.com"
if p.BaseURL != nil {
url = *p.BaseURL
} }
return &google.Client{ return anthropic, nil
BaseURL: url,
APIKey: *p.APIKey,
}, nil
case "openai": case "openai":
url := "https://api.openai.com/v1" url := "https://api.openai.com/v1"
if p.BaseURL != nil { if p.BaseURL != nil {
url = *p.BaseURL url = *p.BaseURL
} }
return &openai.OpenAIClient{ openai := &openai.OpenAIClient{
BaseURL: url, BaseURL: url,
APIKey: *p.APIKey, APIKey: *p.APIKey,
}, nil }
return openai, nil
default: default:
return nil, fmt.Errorf("unknown provider kind: %s", *p.Kind) return nil, fmt.Errorf("unknown provider kind: %s", *p.Kind)
} }

View File

@ -17,27 +17,18 @@ const (
type Message struct { type Message struct {
ID uint `gorm:"primaryKey"` ID uint `gorm:"primaryKey"`
ConversationID uint `gorm:"index"` ConversationID uint `gorm:"foreignKey:ConversationID"`
Conversation Conversation `gorm:"foreignKey:ConversationID"`
Content string Content string
Role MessageRole Role MessageRole
CreatedAt time.Time `gorm:"index"` CreatedAt time.Time
ToolCalls ToolCalls // a json array of tool calls (from the model) ToolCalls ToolCalls // a json array of tool calls (from the modl)
ToolResults ToolResults // a json array of tool results ToolResults ToolResults // a json array of tool results
ParentID *uint `gorm:"index"`
Parent *Message `gorm:"foreignKey:ParentID"`
Replies []Message `gorm:"foreignKey:ParentID"`
SelectedReplyID *uint
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
} }
type Conversation struct { type Conversation struct {
ID uint `gorm:"primaryKey"` ID uint `gorm:"primaryKey"`
ShortName sql.NullString ShortName sql.NullString
Title string Title string
SelectedRootID *uint
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
} }
type RequestParameters struct { type RequestParameters struct {

View File

@ -1,413 +0,0 @@
package google
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
)
func convertTools(tools []model.Tool) []Tool {
geminiTools := make([]Tool, len(tools))
for i, tool := range tools {
params := make(map[string]ToolParameter)
var required []string
for _, param := range tool.Parameters {
// TODO: proper enum handing
params[param.Name] = ToolParameter{
Type: param.Type,
Description: param.Description,
Values: param.Enum,
}
if param.Required {
required = append(required, param.Name)
}
}
geminiTools[i] = Tool{
FunctionDeclarations: []FunctionDeclaration{
{
Name: tool.Name,
Description: tool.Description,
Parameters: ToolParameters{
Type: "OBJECT",
Properties: params,
Required: required,
},
},
},
}
}
return geminiTools
}
func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart {
converted := make([]ContentPart, len(toolCalls))
for i, call := range toolCalls {
args := make(map[string]string)
for k, v := range call.Parameters {
args[k] = fmt.Sprintf("%v", v)
}
converted[i].FunctionCall = &FunctionCall{
Name: call.Name,
Args: args,
}
}
return converted
}
func convertToolCallToAPI(functionCalls []FunctionCall) []model.ToolCall {
converted := make([]model.ToolCall, len(functionCalls))
for i, call := range functionCalls {
params := make(map[string]interface{})
for k, v := range call.Args {
params[k] = v
}
converted[i].Name = call.Name
converted[i].Parameters = params
}
return converted
}
func convertToolResultsToGemini(toolResults []model.ToolResult) ([]FunctionResponse, error) {
results := make([]FunctionResponse, len(toolResults))
for i, result := range toolResults {
var obj interface{}
err := json.Unmarshal([]byte(result.Result), &obj)
if err != nil {
return nil, fmt.Errorf("Could not unmarshal %s: %v", result.Result, err)
}
results[i] = FunctionResponse{
Name: result.ToolName,
Response: obj,
}
}
return results, nil
}
func createGenerateContentRequest(
params model.RequestParameters,
messages []model.Message,
) (*GenerateContentRequest, error) {
requestContents := make([]Content, 0, len(messages))
startIdx := 0
var system string
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
system = messages[0].Content
startIdx = 1
}
for _, m := range messages[startIdx:] {
switch m.Role {
case "tool_call":
content := Content{
Role: "model",
Parts: convertToolCallToGemini(m.ToolCalls),
}
requestContents = append(requestContents, content)
case "tool_result":
results, err := convertToolResultsToGemini(m.ToolResults)
if err != nil {
return nil, err
}
// expand tool_result messages' results into multiple gemini messages
for _, result := range results {
content := Content{
Role: "function",
Parts: []ContentPart{
{
FunctionResp: &result,
},
},
}
requestContents = append(requestContents, content)
}
default:
var role string
switch m.Role {
case model.MessageRoleAssistant:
role = "model"
case model.MessageRoleUser:
role = "user"
}
if role == "" {
panic("Unhandled role: " + m.Role)
}
content := Content{
Role: role,
Parts: []ContentPart{
{
Text: m.Content,
},
},
}
requestContents = append(requestContents, content)
}
}
request := &GenerateContentRequest{
Contents: requestContents,
SystemInstructions: system,
GenerationConfig: &GenerationConfig{
MaxOutputTokens: &params.MaxTokens,
Temperature: &params.Temperature,
TopP: &params.TopP,
},
}
if len(params.ToolBag) > 0 {
request.Tools = convertTools(params.ToolBag)
}
return request, nil
}
func handleToolCalls(
params model.RequestParameters,
content string,
toolCalls []model.ToolCall,
callback provider.ReplyCallback,
messages []model.Message,
) ([]model.Message, error) {
lastMessage := messages[len(messages)-1]
continuation := false
if lastMessage.Role.IsAssistant() {
continuation = true
}
toolCall := model.Message{
Role: model.MessageRoleToolCall,
Content: content,
ToolCalls: toolCalls,
}
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
if err != nil {
return nil, err
}
toolResult := model.Message{
Role: model.MessageRoleToolResult,
ToolResults: toolResults,
}
if callback != nil {
callback(toolCall)
callback(toolResult)
}
if continuation {
messages[len(messages)-1] = toolCall
} else {
messages = append(messages, toolCall)
}
messages = append(messages, toolResult)
return messages, nil
}
func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req.WithContext(ctx))
if resp.StatusCode != 200 {
bytes, _ := io.ReadAll(resp.Body)
return resp, fmt.Errorf("%v", string(bytes))
}
return resp, err
}
func (c *Client) CreateChatCompletion(
ctx context.Context,
params model.RequestParameters,
messages []model.Message,
callback provider.ReplyCallback,
) (string, error) {
if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages")
}
req, err := createGenerateContentRequest(params, messages)
if err != nil {
return "", err
}
jsonData, err := json.Marshal(req)
if err != nil {
return "", err
}
url := fmt.Sprintf(
"%s/v1beta/models/%s:generateContent?key=%s",
c.BaseURL, params.Model, c.APIKey,
)
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return "", err
}
resp, err := c.sendRequest(ctx, httpReq)
if err != nil {
return "", err
}
defer resp.Body.Close()
var completionResp GenerateContentResponse
err = json.NewDecoder(resp.Body).Decode(&completionResp)
if err != nil {
return "", err
}
choice := completionResp.Candidates[0]
var content string
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
content = lastMessage.Content
}
var toolCalls []FunctionCall
for _, part := range choice.Content.Parts {
if part.Text != "" {
content += part.Text
}
if part.FunctionCall != nil {
toolCalls = append(toolCalls, *part.FunctionCall)
}
}
if len(toolCalls) > 0 {
messages, err := handleToolCalls(
params, content, convertToolCallToAPI(toolCalls), callback, messages,
)
if err != nil {
return content, err
}
return c.CreateChatCompletion(ctx, params, messages, callback)
}
if callback != nil {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: content,
})
}
return content, nil
}
func (c *Client) CreateChatCompletionStream(
ctx context.Context,
params model.RequestParameters,
messages []model.Message,
callback provider.ReplyCallback,
output chan<- string,
) (string, error) {
if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages")
}
req, err := createGenerateContentRequest(params, messages)
if err != nil {
return "", err
}
jsonData, err := json.Marshal(req)
if err != nil {
return "", err
}
url := fmt.Sprintf(
"%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
c.BaseURL, params.Model, c.APIKey,
)
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return "", err
}
resp, err := c.sendRequest(ctx, httpReq)
if err != nil {
return "", err
}
defer resp.Body.Close()
content := strings.Builder{}
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
content.WriteString(lastMessage.Content)
}
var toolCalls []FunctionCall
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
break
}
return "", err
}
line = bytes.TrimSpace(line)
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
continue
}
line = bytes.TrimPrefix(line, []byte("data: "))
var streamResp GenerateContentResponse
err = json.Unmarshal(line, &streamResp)
if err != nil {
return "", err
}
for _, candidate := range streamResp.Candidates {
for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil {
toolCalls = append(toolCalls, *part.FunctionCall)
} else if part.Text != "" {
output <- part.Text
content.WriteString(part.Text)
}
}
}
}
// If there are function calls, handle them and recurse
if len(toolCalls) > 0 {
messages, err := handleToolCalls(
params, content.String(), convertToolCallToAPI(toolCalls), callback, messages,
)
if err != nil {
return content.String(), err
}
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
}
if callback != nil {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: content.String(),
})
}
return content.String(), nil
}

View File

@ -1,80 +0,0 @@
package google
type Client struct {
APIKey string
BaseURL string
}
type ContentPart struct {
Text string `json:"text,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
FunctionResp *FunctionResponse `json:"functionResponse,omitempty"`
}
type FunctionCall struct {
Name string `json:"name"`
Args map[string]string `json:"args"`
}
type FunctionResponse struct {
Name string `json:"name"`
Response interface{} `json:"response"`
}
type Content struct {
Role string `json:"role"`
Parts []ContentPart `json:"parts"`
}
type GenerationConfig struct {
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
Temperature *float32 `json:"temperature,omitempty"`
TopP *float32 `json:"topP,omitempty"`
TopK *int `json:"topK,omitempty"`
}
type GenerateContentRequest struct {
Contents []Content `json:"contents"`
Tools []Tool `json:"tools,omitempty"`
SystemInstructions string `json:"systemInstructions,omitempty"`
GenerationConfig *GenerationConfig `json:"generationConfig,omitempty"`
}
type Candidate struct {
Content Content `json:"content"`
FinishReason string `json:"finishReason"`
Index int `json:"index"`
}
type UsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
}
type GenerateContentResponse struct {
Candidates []Candidate `json:"candidates"`
UsageMetadata UsageMetadata `json:"usageMetadata"`
}
type Tool struct {
FunctionDeclarations []FunctionDeclaration `json:"functionDeclarations"`
}
type FunctionDeclaration struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters ToolParameters `json:"parameters"`
}
type ToolParameters struct {
Type string `json:"type"`
Properties map[string]ToolParameter `json:"properties,omitempty"`
Required []string `json:"required,omitempty"`
}
type ToolParameter struct {
Type string `json:"type"`
Description string `json:"description"`
Values []string `json:"values,omitempty"`
}

View File

@ -13,26 +13,21 @@ import (
) )
type ConversationStore interface { type ConversationStore interface {
Conversations() ([]model.Conversation, error)
ConversationByShortName(shortName string) (*model.Conversation, error) ConversationByShortName(shortName string) (*model.Conversation, error)
ConversationShortNameCompletions(search string) []string ConversationShortNameCompletions(search string) []string
RootMessages(conversationID uint) ([]model.Message, error)
LatestConversationMessages() ([]model.Message, error)
StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error) SaveConversation(conversation *model.Conversation) error
UpdateConversation(conversation *model.Conversation) error
DeleteConversation(conversation *model.Conversation) error DeleteConversation(conversation *model.Conversation) error
CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error)
MessageByID(messageID uint) (*model.Message, error) Messages(conversation *model.Conversation) ([]model.Message, error)
MessageReplies(messageID uint) ([]model.Message, error) LastMessage(conversation *model.Conversation) (*model.Message, error)
SaveMessage(message *model.Message) error
DeleteMessage(message *model.Message) error
UpdateMessage(message *model.Message) error UpdateMessage(message *model.Message) error
DeleteMessage(message *model.Message, prune bool) error AddReply(conversation *model.Conversation, message model.Message) (*model.Message, error)
CloneBranch(toClone model.Message) (*model.Message, uint, error)
Reply(to *model.Message, messages ...model.Message) ([]model.Message, error)
PathToRoot(message *model.Message) ([]model.Message, error)
PathToLeaf(message *model.Message) ([]model.Message, error)
} }
type SQLStore struct { type SQLStore struct {
@ -57,52 +52,47 @@ func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
return &SQLStore{db, _sqids}, nil return &SQLStore{db, _sqids}, nil
} }
func (s *SQLStore) saveNewConversation(c *model.Conversation) error { func (s *SQLStore) SaveConversation(conversation *model.Conversation) error {
// Save the new conversation err := s.db.Save(&conversation).Error
err := s.db.Save(&c).Error
if err != nil { if err != nil {
return err return err
} }
// Generate and save its "short name" if !conversation.ShortName.Valid {
shortName, _ := s.sqids.Encode([]uint64{uint64(c.ID)}) shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)})
c.ShortName = sql.NullString{String: shortName, Valid: true} conversation.ShortName = sql.NullString{String: shortName, Valid: true}
return s.UpdateConversation(c) err = s.db.Save(&conversation).Error
} }
func (s *SQLStore) UpdateConversation(c *model.Conversation) error {
if c == nil || c.ID == 0 {
return fmt.Errorf("Conversation is nil or invalid (missing ID)")
}
return s.db.Updates(&c).Error
}
func (s *SQLStore) DeleteConversation(c *model.Conversation) error {
// Delete messages first
err := s.db.Where("conversation_id = ?", c.ID).Delete(&model.Message{}).Error
if err != nil {
return err return err
} }
return s.db.Delete(&c).Error
func (s *SQLStore) DeleteConversation(conversation *model.Conversation) error {
s.db.Where("conversation_id = ?", conversation.ID).Delete(&model.Message{})
return s.db.Delete(&conversation).Error
} }
func (s *SQLStore) DeleteMessage(message *model.Message, prune bool) error { func (s *SQLStore) SaveMessage(message *model.Message) error {
panic("Not yet implemented") return s.db.Create(message).Error
//return s.db.Delete(&message).Error
} }
func (s *SQLStore) UpdateMessage(m *model.Message) error { func (s *SQLStore) DeleteMessage(message *model.Message) error {
if m == nil || m.ID == 0 { return s.db.Delete(&message).Error
return fmt.Errorf("Message is nil or invalid (missing ID)")
} }
return s.db.Updates(&m).Error
func (s *SQLStore) UpdateMessage(message *model.Message) error {
return s.db.Updates(&message).Error
}
func (s *SQLStore) Conversations() ([]model.Conversation, error) {
var conversations []model.Conversation
err := s.db.Find(&conversations).Error
return conversations, err
} }
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string { func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
var conversations []model.Conversation var completions []string
// ignore error for completions conversations, _ := s.Conversations() // ignore error for completions
s.db.Find(&conversations)
completions := make([]string, 0, len(conversations))
for _, conversation := range conversations { for _, conversation := range conversations {
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) { if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title)) completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
@ -116,249 +106,27 @@ func (s *SQLStore) ConversationByShortName(shortName string) (*model.Conversatio
return nil, errors.New("shortName is empty") return nil, errors.New("shortName is empty")
} }
var conversation model.Conversation var conversation model.Conversation
err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error err := s.db.Where("short_name = ?", shortName).Find(&conversation).Error
return &conversation, err return &conversation, err
} }
func (s *SQLStore) RootMessages(conversationID uint) ([]model.Message, error) { func (s *SQLStore) Messages(conversation *model.Conversation) ([]model.Message, error) {
var rootMessages []model.Message var messages []model.Message
err := s.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error
if err != nil { return messages, err
return nil, err
}
return rootMessages, nil
} }
func (s *SQLStore) MessageByID(messageID uint) (*model.Message, error) { func (s *SQLStore) LastMessage(conversation *model.Conversation) (*model.Message, error) {
var message model.Message var message model.Message
err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error
return &message, err return &message, err
} }
func (s *SQLStore) MessageReplies(messageID uint) ([]model.Message, error) { // AddReply adds the given messages as a reply to the given conversation, can be
var replies []model.Message // used to easily copy a message associated with one conversation, to another
err := s.db.Where("parent_id = ?", messageID).Find(&replies).Error func (s *SQLStore) AddReply(c *model.Conversation, m model.Message) (*model.Message, error) {
return replies, err m.ConversationID = c.ID
} m.ID = 0
m.CreatedAt = time.Time{}
// StartConversation starts a new conversation with the provided messages return &m, s.SaveMessage(&m)
func (s *SQLStore) StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error) {
if len(messages) == 0 {
return nil, nil, fmt.Errorf("Must provide at least 1 message")
}
// Create new conversation
conversation := &model.Conversation{}
err := s.saveNewConversation(conversation)
if err != nil {
return nil, nil, err
}
// Create first message
messages[0].ConversationID = conversation.ID
err = s.db.Create(&messages[0]).Error
if err != nil {
return nil, nil, err
}
// Update conversation's selected root message
conversation.SelectedRoot = &messages[0]
err = s.UpdateConversation(conversation)
if err != nil {
return nil, nil, err
}
// Add additional replies to conversation
if len(messages) > 1 {
newMessages, err := s.Reply(&messages[0], messages[1:]...)
if err != nil {
return nil, nil, err
}
messages = append([]model.Message{messages[0]}, newMessages...)
}
return conversation, messages, nil
}
// CloneConversation clones the given conversation and all of its root meesages
func (s *SQLStore) CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error) {
rootMessages, err := s.RootMessages(toClone.ID)
if err != nil {
return nil, 0, err
}
clone := &model.Conversation{
Title: toClone.Title + " - Clone",
}
if err := s.saveNewConversation(clone); err != nil {
return nil, 0, fmt.Errorf("Could not create clone: %s", err)
}
var errors []error
var messageCnt uint = 0
for _, root := range rootMessages {
messageCnt++
newRoot := root
newRoot.ConversationID = clone.ID
cloned, count, err := s.CloneBranch(newRoot)
if err != nil {
errors = append(errors, err)
continue
}
messageCnt += count
if root.ID == *toClone.SelectedRootID {
clone.SelectedRootID = &cloned.ID
if err := s.UpdateConversation(clone); err != nil {
errors = append(errors, fmt.Errorf("Could not set selected root on clone: %v", err))
}
}
}
if len(errors) > 0 {
return nil, 0, fmt.Errorf("Messages failed to be cloned: %v", errors)
}
return clone, messageCnt, nil
}
// Reply replies to the given parentMessage with a series of messages
func (s *SQLStore) Reply(parentMessage *model.Message, messages ...model.Message) ([]model.Message, error) {
var savedMessages []model.Message
currentParent := parentMessage
err := s.db.Transaction(func(tx *gorm.DB) error {
for i := range messages {
message := messages[i]
message.ConversationID = currentParent.ConversationID
message.ParentID = &currentParent.ID
message.ID = 0
message.CreatedAt = time.Time{}
if err := tx.Create(&message).Error; err != nil {
return err
}
// update parent selected reply
currentParent.SelectedReply = &message
if err := tx.Model(currentParent).Update("selected_reply_id", message.ID).Error; err != nil {
return err
}
savedMessages = append(savedMessages, message)
currentParent = &message
}
return nil
})
return savedMessages, err
}
// CloneBranch returns a deep clone of the given message and its replies, returning
// a new message object. The new message will be attached to the same parent as
// the message to clone.
func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, uint, error) {
newMessage := messageToClone
newMessage.ID = 0
newMessage.Replies = nil
newMessage.SelectedReplyID = nil
newMessage.SelectedReply = nil
if err := s.db.Create(&newMessage).Error; err != nil {
return nil, 0, fmt.Errorf("Could not clone message: %s", err)
}
originalReplies, err := s.MessageReplies(messageToClone.ID)
if err != nil {
return nil, 0, fmt.Errorf("Could not fetch message %d replies: %v", messageToClone.ID, err)
}
var replyCount uint = 0
for _, reply := range originalReplies {
replyCount++
newReply := reply
newReply.ConversationID = messageToClone.ConversationID
newReply.ParentID = &newMessage.ID
newReply.Parent = &newMessage
res, c, err := s.CloneBranch(newReply)
if err != nil {
return nil, 0, err
}
newMessage.Replies = append(newMessage.Replies, *res)
replyCount += c
if reply.ID == *messageToClone.SelectedReplyID {
newMessage.SelectedReplyID = &res.ID
if err := s.UpdateMessage(&newMessage); err != nil {
return nil, 0, fmt.Errorf("Could not update parent select reply ID: %v", err)
}
}
}
return &newMessage, replyCount, nil
}
// PathToRoot traverses message Parent until reaching the tree root
func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {
if message == nil {
return nil, fmt.Errorf("Message is nil")
}
var path []model.Message
current := message
for {
path = append([]model.Message{*current}, path...)
if current.Parent == nil {
break
}
var err error
current, err = s.MessageByID(*current.ParentID)
if err != nil {
return nil, fmt.Errorf("finding parent message: %w", err)
}
}
return path, nil
}
// PathToLeaf traverses message SelectedReply until reaching a tree leaf
func (s *SQLStore) PathToLeaf(message *model.Message) ([]model.Message, error) {
if message == nil {
return nil, fmt.Errorf("Message is nil")
}
var path []model.Message
current := message
for {
path = append(path, *current)
if current.SelectedReplyID == nil {
break
}
var err error
current, err = s.MessageByID(*current.SelectedReplyID)
if err != nil {
return nil, fmt.Errorf("finding selected reply: %w", err)
}
}
return path, nil
}
func (s *SQLStore) LatestConversationMessages() ([]model.Message, error) {
var latestMessages []model.Message
subQuery := s.db.Model(&model.Message{}).
Select("MAX(created_at) as max_created_at, conversation_id").
Group("conversation_id")
err := s.db.Model(&model.Message{}).
Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery).
Order("created_at DESC").
Preload("Conversation").
Find(&latestMessages).Error
if err != nil {
return nil, err
}
return latestMessages, nil
} }

View File

@ -11,9 +11,7 @@ import (
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util" toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
) )
const TREE_DESCRIPTION = `Retrieve a tree-like view of a directory's contents. const TREE_DESCRIPTION = `Retrieve a tree view of a directory's contents.
Use these results for your own reference in completing your task, they do not need to be shown to the user.
Example result: Example result:
{ {
@ -37,45 +35,48 @@ var DirTreeTool = model.Tool{
Description: "If set, display the tree starting from this path relative to the current one.", Description: "If set, display the tree starting from this path relative to the current one.",
}, },
{ {
Name: "depth", Name: "max_depth",
Type: "integer", Type: "integer",
Description: "Depth of directory recursion. Default 0. Use -1 for unlimited.", Description: "Maximum depth of recursion. Default is unlimited.",
}, },
}, },
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) { Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
var relativeDir string var relativeDir string
if tmp, ok := args["relative_path"]; ok { tmp, ok := args["relative_dir"]
if ok {
relativeDir, ok = tmp.(string) relativeDir, ok = tmp.(string)
if !ok { if !ok {
return "", fmt.Errorf("expected string for relative_path, got %T", tmp) return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp)
} }
} }
var depth int = 0 // Default value if not provided var maxDepth int = -1
if tmp, ok := args["depth"]; ok { tmp, ok = args["max_depth"]
switch v := tmp.(type) { if ok {
case float64: maxDepth, ok = tmp.(int)
depth = int(v) if !ok {
case string: if tmps, ok := tmp.(string); ok {
var err error tmpi, err := strconv.Atoi(tmps)
if depth, err = strconv.Atoi(v); err != nil { maxDepth = tmpi
return "", fmt.Errorf("invalid `depth` value, expected integer but got string that cannot convert: %v", tmp) if err != nil {
return "", fmt.Errorf("Invalid max_depth in function arguments: %v", tmp)
}
} else {
return "", fmt.Errorf("Invalid max_depth in function arguments: %v", tmp)
} }
default:
return "", fmt.Errorf("expected int or string for max_depth, got %T", tmp)
} }
} }
result := tree(relativeDir, depth) result := tree(relativeDir, maxDepth)
ret, err := result.ToJson() ret, err := result.ToJson()
if err != nil { if err != nil {
return "", fmt.Errorf("could not serialize result: %v", err) return "", fmt.Errorf("Could not serialize result: %v", err)
} }
return ret, nil return ret, nil
}, },
} }
func tree(path string, depth int) model.CallResult { func tree(path string, maxDepth int) model.CallResult {
if path == "" { if path == "" {
path = "." path = "."
} }
@ -86,7 +87,7 @@ func tree(path string, depth int) model.CallResult {
var treeOutput strings.Builder var treeOutput strings.Builder
treeOutput.WriteString(path + "\n") treeOutput.WriteString(path + "\n")
err := buildTree(&treeOutput, path, "", depth) err := buildTree(&treeOutput, path, "", maxDepth)
if err != nil { if err != nil {
return model.CallResult{ return model.CallResult{
Message: err.Error(), Message: err.Error(),
@ -96,7 +97,7 @@ func tree(path string, depth int) model.CallResult {
return model.CallResult{Result: treeOutput.String()} return model.CallResult{Result: treeOutput.String()}
} }
func buildTree(output *strings.Builder, path string, prefix string, depth int) error { func buildTree(output *strings.Builder, path string, prefix string, maxDepth int) error {
files, err := os.ReadDir(path) files, err := os.ReadDir(path)
if err != nil { if err != nil {
return err return err
@ -123,14 +124,14 @@ func buildTree(output *strings.Builder, path string, prefix string, depth int) e
output.WriteString(prefix + branch + file.Name()) output.WriteString(prefix + branch + file.Name())
if file.IsDir() { if file.IsDir() {
output.WriteString("/\n") output.WriteString("/\n")
if depth != 0 { if maxDepth != 0 {
var nextPrefix string var nextPrefix string
if isLast { if isLast {
nextPrefix = prefix + " " nextPrefix = prefix + " "
} else { } else {
nextPrefix = prefix + "│ " nextPrefix = prefix + "│ "
} }
buildTree(output, filepath.Join(path, file.Name()), nextPrefix, depth-1) buildTree(output, filepath.Join(path, file.Name()), nextPrefix, maxDepth-1)
} }
} else { } else {
output.WriteString(sizeStr + "\n") output.WriteString(sizeStr + "\n")
@ -139,3 +140,4 @@ func buildTree(output *strings.Builder, path string, prefix string, depth int) e
return nil return nil
} }

View File

@ -9,9 +9,7 @@ import (
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util" toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
) )
const READ_FILE_DESCRIPTION = `Retrieve the contents of a text file relative to the current working directory. const READ_FILE_DESCRIPTION = `Read the contents of a text file relative to the current working directory.
Use the file contents for your own reference in completing your task, they do not need to be shown to the user.
Each line of the returned content is prefixed with its line number and a tab (\t). Each line of the returned content is prefixed with its line number and a tab (\t).

View File

@ -66,10 +66,6 @@ type chatModel struct {
replyChunkChan chan string replyChunkChan chan string
persistence bool // whether we will save new messages in the conversation persistence bool // whether we will save new messages in the conversation
tokenCount uint
startTime time.Time
elapsed time.Duration
// ui state // ui state
focus focusState focus focusState
wrap bool // whether message content is wrapped to viewport width wrap bool // whether message content is wrapped to viewport width
@ -286,9 +282,6 @@ func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
} }
m.updateContent() m.updateContent()
cmds = append(cmds, m.waitForChunk()) // wait for the next chunk cmds = append(cmds, m.waitForChunk()) // wait for the next chunk
m.tokenCount++
m.elapsed = time.Now().Sub(m.startTime)
case msgAssistantReply: case msgAssistantReply:
// the last reply that was being worked on is finished // the last reply that was being worked on is finished
reply := models.Message(msg) reply := models.Message(msg)
@ -307,9 +300,14 @@ func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
} }
if m.persistence { if m.persistence {
err := m.persistConversation() var err error
if m.conversation.ID == 0 {
err = m.ctx.Store.SaveConversation(m.conversation)
}
if err != nil { if err != nil {
cmds = append(cmds, wrapError(err)) cmds = append(cmds, wrapError(err))
} else {
cmds = append(cmds, m.persistConversation())
} }
} }
@ -336,7 +334,7 @@ func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
title := string(msg) title := string(msg)
m.conversation.Title = title m.conversation.Title = title
if m.persistence { if m.persistence {
err := m.ctx.Store.UpdateConversation(m.conversation) err := m.ctx.Store.SaveConversation(m.conversation)
if err != nil { if err != nil {
cmds = append(cmds, wrapError(err)) cmds = append(cmds, wrapError(err))
} }
@ -464,8 +462,8 @@ func (m *chatModel) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
m.input.Blur() m.input.Blur()
return true, nil return true, nil
case "ctrl+s": case "ctrl+s":
input := strings.TrimSpace(m.input.Value()) userInput := strings.TrimSpace(m.input.Value())
if input == "" { if strings.TrimSpace(userInput) == "" {
return true, nil return true, nil
} }
@ -473,20 +471,36 @@ func (m *chatModel) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
return true, wrapError(fmt.Errorf("Can't reply to a user message")) return true, wrapError(fmt.Errorf("Can't reply to a user message"))
} }
m.addMessage(models.Message{ reply := models.Message{
Role: models.MessageRoleUser, Role: models.MessageRoleUser,
Content: input, Content: userInput,
}) }
m.input.SetValue("")
if m.persistence { if m.persistence {
err := m.persistConversation() var err error
if m.conversation.ID == 0 {
err = m.ctx.Store.SaveConversation(m.conversation)
}
if err != nil { if err != nil {
return true, wrapError(err) return true, wrapError(err)
} }
// ensure all messages up to the one we're about to add are persisted
cmd := m.persistConversation()
if cmd != nil {
return true, cmd
} }
savedReply, err := m.ctx.Store.AddReply(m.conversation, reply)
if err != nil {
return true, wrapError(err)
}
reply = *savedReply
}
m.input.SetValue("")
m.addMessage(reply)
m.updateContent() m.updateContent()
m.content.GotoBottom() m.content.GotoBottom()
return true, m.promptLLM() return true, m.promptLLM()
@ -679,16 +693,10 @@ func (m *chatModel) footerView() string {
saving, saving,
segmentStyle.Render(status), segmentStyle.Render(status),
} }
rightSegments := []string{} rightSegments := []string{
segmentStyle.Render(fmt.Sprintf("Model: %s", *m.ctx.Config.Defaults.Model)),
if m.elapsed > 0 && m.tokenCount > 0 {
throughput := fmt.Sprintf("%.0f t/sec", float64(m.tokenCount)/m.elapsed.Seconds())
rightSegments = append(rightSegments, segmentStyle.Render(throughput))
} }
model := fmt.Sprintf("Model: %s", *m.ctx.Config.Defaults.Model)
rightSegments = append(rightSegments, segmentStyle.Render(model))
left := strings.Join(leftSegments, segmentSeparator) left := strings.Join(leftSegments, segmentSeparator)
right := strings.Join(rightSegments, segmentSeparator) right := strings.Join(rightSegments, segmentSeparator)
@ -762,7 +770,7 @@ func (m *chatModel) loadConversation(shortname string) tea.Cmd {
func (m *chatModel) loadMessages(c *models.Conversation) tea.Cmd { func (m *chatModel) loadMessages(c *models.Conversation) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
messages, err := m.ctx.Store.PathToLeaf(c.SelectedRoot) messages, err := m.ctx.Store.Messages(c)
if err != nil { if err != nil {
return msgError(fmt.Errorf("Could not load conversation messages: %v\n", err)) return msgError(fmt.Errorf("Could not load conversation messages: %v\n", err))
} }
@ -770,48 +778,62 @@ func (m *chatModel) loadMessages(c *models.Conversation) tea.Cmd {
} }
} }
func (m *chatModel) persistConversation() error { func (m *chatModel) persistConversation() tea.Cmd {
if m.conversation.ID == 0 { existingMessages, err := m.ctx.Store.Messages(m.conversation)
// Start a new conversation with all messages so far
c, messages, err := m.ctx.Store.StartConversation(m.messages...)
if err != nil { if err != nil {
return err return wrapError(fmt.Errorf("Could not retrieve existing conversation messages while trying to save: %v", err))
}
m.conversation = c
m.messages = messages
return nil
} }
// else, we'll handle updating an existing conversation's messages existingById := make(map[uint]*models.Message, len(existingMessages))
for i := 0; i < len(m.messages); i++ { for _, msg := range existingMessages {
if m.messages[i].ID > 0 { existingById[msg.ID] = &msg
// message has an ID, update its contents
// TODO: check for content/tool equality before updating?
err := m.ctx.Store.UpdateMessage(&m.messages[i])
if err != nil {
return err
} }
} else if i > 0 {
// messages is new, so add it as a reply to previous message currentById := make(map[uint]*models.Message, len(m.messages))
saved, err := m.ctx.Store.Reply(&m.messages[i-1], m.messages[i]) for _, msg := range m.messages {
if err != nil { currentById[msg.ID] = &msg
return err }
for _, msg := range existingMessages {
_, ok := currentById[msg.ID]
if !ok {
err := m.ctx.Store.DeleteMessage(&msg)
if err != nil {
return wrapError(fmt.Errorf("Failed to remove messages: %v", err))
}
}
}
for i, msg := range m.messages {
if msg.ID > 0 {
exist, ok := existingById[msg.ID]
if ok {
if msg.Content == exist.Content {
continue
}
// update message when contents don't match that of store
err := m.ctx.Store.UpdateMessage(&msg)
if err != nil {
return wrapError(err)
} }
m.messages[i] = saved[0]
} else { } else {
// message has no id and no previous messages to add it to // this would be quite odd... and I'm not sure how to handle
// this shouldn't happen? // it at the time of writing this
return fmt.Errorf("Error: no messages to reply to") }
} else {
newMessage, err := m.ctx.Store.AddReply(m.conversation, msg)
if err != nil {
return wrapError(err)
}
m.setMessage(i, *newMessage)
} }
} }
return nil return nil
} }
func (m *chatModel) generateConversationTitle() tea.Cmd { func (m *chatModel) generateConversationTitle() tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
title, err := cmdutil.GenerateTitle(m.ctx, m.messages) title, err := cmdutil.GenerateTitle(m.ctx, m.conversation)
if err != nil { if err != nil {
return msgError(err) return msgError(err)
} }
@ -835,10 +857,6 @@ func (m *chatModel) promptLLM() tea.Cmd {
m.waitingForReply = true m.waitingForReply = true
m.status = "Press ctrl+c to cancel" m.status = "Press ctrl+c to cancel"
m.tokenCount = 0
m.startTime = time.Now()
m.elapsed = 0
return func() tea.Msg { return func() tea.Msg {
completionProvider, err := m.ctx.GetCompletionProvider(*m.ctx.Config.Defaults.Model) completionProvider, err := m.ctx.GetCompletionProvider(*m.ctx.Config.Defaults.Model)
if err != nil { if err != nil {

View File

@ -2,6 +2,7 @@ package tui
import ( import (
"fmt" "fmt"
"slices"
"strings" "strings"
"time" "time"
@ -114,7 +115,7 @@ func (m *conversationsModel) handleResize(width, height int) {
func (m conversationsModel) Update(msg tea.Msg) (conversationsModel, tea.Cmd) { func (m conversationsModel) Update(msg tea.Msg) (conversationsModel, tea.Cmd) {
var cmds []tea.Cmd var cmds []tea.Cmd
switch msg := msg.(type) { switch msg := msg.(type) {
case msgStateEnter: case msgStateChange:
cmds = append(cmds, m.loadConversations()) cmds = append(cmds, m.loadConversations())
m.content.SetContent(m.renderConversationList()) m.content.SetContent(m.renderConversationList())
case tea.WindowSizeMsg: case tea.WindowSizeMsg:
@ -144,16 +145,24 @@ func (m conversationsModel) Update(msg tea.Msg) (conversationsModel, tea.Cmd) {
func (m *conversationsModel) loadConversations() tea.Cmd { func (m *conversationsModel) loadConversations() tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
messages, err := m.ctx.Store.LatestConversationMessages() conversations, err := m.ctx.Store.Conversations()
if err != nil { if err != nil {
return msgError(fmt.Errorf("Could not load conversations: %v", err)) return msgError(fmt.Errorf("Could not load conversations: %v", err))
} }
loaded := make([]loadedConversation, len(messages)) loaded := make([]loadedConversation, len(conversations))
for i, m := range messages { for i, c := range conversations {
loaded[i].lastReply = m lastMessage, err := m.ctx.Store.LastMessage(&c)
loaded[i].conv = m.Conversation if err != nil {
return msgError(err)
} }
loaded[i].conv = c
loaded[i].lastReply = *lastMessage
}
slices.SortFunc(loaded, func(a, b loadedConversation) int {
return b.lastReply.CreatedAt.Compare(a.lastReply.CreatedAt)
})
return msgConversationsLoaded(loaded) return msgConversationsLoaded(loaded)
} }

View File

@ -147,16 +147,16 @@ func SetStructDefaults(data interface{}) bool {
case reflect.String: case reflect.String:
defaultValue := defaultTag defaultValue := defaultTag
field.Set(reflect.ValueOf(&defaultValue)) field.Set(reflect.ValueOf(&defaultValue))
case reflect.Uint, reflect.Uint32, reflect.Uint64:
intValue, _ := strconv.ParseUint(defaultTag, 10, e.Bits())
field.Set(reflect.New(e))
field.Elem().SetUint(intValue)
case reflect.Int, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int32, reflect.Int64:
intValue, _ := strconv.ParseInt(defaultTag, 10, e.Bits()) intValue, _ := strconv.ParseInt(defaultTag, 10, 64)
field.Set(reflect.New(e)) field.Set(reflect.New(e))
field.Elem().SetInt(intValue) field.Elem().SetInt(intValue)
case reflect.Float32, reflect.Float64: case reflect.Float32:
floatValue, _ := strconv.ParseFloat(defaultTag, e.Bits()) floatValue, _ := strconv.ParseFloat(defaultTag, 32)
field.Set(reflect.New(e))
field.Elem().SetFloat(floatValue)
case reflect.Float64:
floatValue, _ := strconv.ParseFloat(defaultTag, 64)
field.Set(reflect.New(e)) field.Set(reflect.New(e))
field.Elem().SetFloat(floatValue) field.Elem().SetFloat(floatValue)
case reflect.Bool: case reflect.Bool: