diff --git a/go.mod b/go.mod index dbc0be7..c572046 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,8 @@ go 1.21 require ( github.com/alecthomas/chroma/v2 v2.11.1 + github.com/charmbracelet/lipgloss v0.10.0 github.com/go-yaml/yaml v2.1.0+incompatible - github.com/gookit/color v1.5.4 github.com/sashabaranov/go-openai v1.17.7 github.com/spf13/cobra v1.8.0 github.com/sqids/sqids-go v0.4.1 @@ -14,14 +14,20 @@ require ( ) require ( + github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/dlclark/regexp2 v1.10.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/kr/pretty v0.3.1 // indirect + github.com/lucasb-eyer/go-colorful v1.2.0 // indirect + github.com/mattn/go-isatty v0.0.18 // indirect + github.com/mattn/go-runewidth v0.0.15 // indirect github.com/mattn/go-sqlite3 v1.14.18 // indirect + github.com/muesli/reflow v0.3.0 // indirect + github.com/muesli/termenv v0.15.2 // indirect + github.com/rivo/uniseg v0.4.7 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect golang.org/x/sys v0.14.0 // indirect gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect gopkg.in/yaml.v2 v2.2.2 // indirect diff --git a/go.sum b/go.sum index c9309f8..a450795 100644 --- a/go.sum +++ b/go.sum @@ -4,16 +4,16 @@ github.com/alecthomas/chroma/v2 v2.11.1 h1:m9uUtgcdAwgfFNxuqj7AIG75jD2YmL61BBIJW github.com/alecthomas/chroma/v2 v2.11.1/go.mod h1:4TQu7gdfuPjSh76j78ietmqh9LiurGF0EpseFXdKMBw= github.com/alecthomas/repr v0.2.0 h1:HAzS41CIzNW5syS8Mf9UwXhNH1J9aix/BvDRf1Ml2Yk= github.com/alecthomas/repr v0.2.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= +github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= +github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= +github.com/charmbracelet/lipgloss v0.10.0 h1:KWeXFSexGcfahHX+54URiZGkBFazf70JNMtwg/AFW3s= +github.com/charmbracelet/lipgloss v0.10.0/go.mod h1:Wig9DSfvANsxqkRsqj6x87irdy123SR4dOXlKa91ciE= github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/go-yaml/yaml v2.1.0+incompatible h1:RYi2hDdss1u4YE7GwixGzWwVo47T8UQwnTLB6vQiq+o= github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0= -github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0= -github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= @@ -26,11 +26,24 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= +github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= +github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98= +github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= +github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= +github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI= github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= +github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= +github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo= +github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -42,10 +55,7 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw= github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 h1:QldyIu/L63oPpyvQmHgvgickp1Yw510KJOqX7H24mg8= -github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -53,7 +63,6 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogR gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/sqlite v1.5.4 h1:IqXwXi8M/ZlPzH/947tn5uik3aYQslP9BVveoax0nV0= gorm.io/driver/sqlite v1.5.4/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4= diff --git a/main.go b/main.go index 28f9102..8ca6132 100644 --- a/main.go +++ b/main.go @@ -1,15 +1,18 @@ package main import ( - "fmt" - "os" - - "git.mlow.ca/mlow/lmcli/pkg/cli" + "git.mlow.ca/mlow/lmcli/pkg/lmcli" + "git.mlow.ca/mlow/lmcli/pkg/cmd" ) func main() { - if err := cli.Execute(); err != nil { - fmt.Fprintln(os.Stderr, err.Error()) - os.Exit(1) + ctx, err := lmcli.NewContext() + if err != nil { + lmcli.Fatal("%v\n", err) + } + + root := cmd.RootCmd(ctx) + if err := root.Execute(); err != nil { + lmcli.Fatal("%v\n", err) } } diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go deleted file mode 100644 index 8ddd3cd..0000000 --- a/pkg/cli/cli.go +++ /dev/null @@ -1,32 +0,0 @@ -package cli - -import ( - "fmt" - "os" -) - -var config *Config -var store *Store - -func init() { - var err error - - config, err = NewConfig() - if err != nil { - Fatal("%v\n", err) - } - - store, err = NewStore() - if err != nil { - Fatal("%v\n", err) - } -} - -func Fatal(format string, args ...any) { - fmt.Fprintf(os.Stderr, format, args...) - os.Exit(1) -} - -func Warn(format string, args ...any) { - fmt.Fprintf(os.Stderr, format, args...) -} diff --git a/pkg/cli/cmd.go b/pkg/cli/cmd.go deleted file mode 100644 index de00a46..0000000 --- a/pkg/cli/cmd.go +++ /dev/null @@ -1,722 +0,0 @@ -package cli - -import ( - "fmt" - "os" - "slices" - "strings" - "time" - - "github.com/spf13/cobra" -) - -var ( - maxTokens int - model string - systemPrompt string - systemPromptFile string -) - -const ( - // Limit number of conversations shown with `ls`, without --all - LS_LIMIT int = 5 -) - -func init() { - inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd, retryCmd, continueCmd, editCmd} - for _, cmd := range inputCmds { - cmd.Flags().IntVar(&maxTokens, "length", *config.OpenAI.DefaultMaxLength, "Maximum response tokens") - cmd.Flags().StringVar(&model, "model", *config.OpenAI.DefaultModel, "Which model to use model") - cmd.Flags().StringVar(&systemPrompt, "system-prompt", *config.ModelDefaults.SystemPrompt, "System prompt") - cmd.Flags().StringVar(&systemPromptFile, "system-prompt-file", "", "A path to a file containing the system prompt") - cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file") - } - - listCmd.Flags().Int("count", LS_LIMIT, fmt.Sprintf("Number of conversations to list")) - listCmd.Flags().Bool("all", false, "List all conversations") - - renameCmd.Flags().Bool("generate", false, "Generate a conversation title") - editCmd.Flags().Int("offset", 1, "Offset from the last reply to edit (Default: edit your last message, assuming there's an assistant reply)") - - rootCmd.AddCommand( - cloneCmd, - continueCmd, - editCmd, - listCmd, - newCmd, - promptCmd, - renameCmd, - replyCmd, - retryCmd, - rmCmd, - viewCmd, - ) -} - -func Execute() error { - return rootCmd.Execute() -} - -func getSystemPrompt() string { - if systemPromptFile != "" { - content, err := FileContents(systemPromptFile) - if err != nil { - Fatal("Could not read file contents at %s: %v\n", systemPromptFile, err) - } - return content - } - return systemPrompt -} - -// fetchAndShowCompletion prompts the LLM with the given messages and streams -// the response to stdout. Returns all model reply messages. -func fetchAndShowCompletion(messages []Message) ([]Message, error) { - content := make(chan string) // receives the reponse from LLM - defer close(content) - - // render all content received over the channel - go ShowDelayedContent(content) - - var replies []Message - response, err := CreateChatCompletionStream(model, messages, maxTokens, content, &replies) - if response != "" { - // there was some content, so break to a new line after it - fmt.Println() - - if err != nil { - Warn("Received partial response. Error: %v\n", err) - err = nil - } - } - - return replies, err -} - -// lookupConversation either returns the conversation found by the -// short name or exits the program -func lookupConversation(shortName string) *Conversation { - c, err := store.ConversationByShortName(shortName) - if err != nil { - Fatal("Could not lookup conversation: %v\n", err) - } - if c.ID == 0 { - Fatal("Conversation not found with short name: %s\n", shortName) - } - return c -} - -func lookupConversationE(shortName string) (*Conversation, error) { - c, err := store.ConversationByShortName(shortName) - if err != nil { - return nil, fmt.Errorf("Could not lookup conversation: %v", err) - } - if c.ID == 0 { - return nil, fmt.Errorf("Conversation not found with short name: %s", shortName) - } - return c, nil -} - -// handleConversationReply handles sending messages to an existing -// conversation, optionally persisting both the sent replies and responses. -func handleConversationReply(c *Conversation, persist bool, toSend ...Message) { - existing, err := store.Messages(c) - if err != nil { - Fatal("Could not retrieve messages for conversation: %s\n", c.Title) - } - - if persist { - for _, message := range toSend { - err = store.SaveMessage(&message) - if err != nil { - Warn("Could not save %s message: %v\n", message.Role, err) - } - } - } - - allMessages := append(existing, toSend...) - - RenderConversation(allMessages, true) - - // render a message header with no contents - (&Message{Role: MessageRoleAssistant}).RenderTTY() - - replies, err := fetchAndShowCompletion(allMessages) - if err != nil { - Fatal("Error fetching LLM response: %v\n", err) - } - - if persist { - for _, reply := range replies { - reply.ConversationID = c.ID - - err = store.SaveMessage(&reply) - if err != nil { - Warn("Could not save reply: %v\n", err) - } - } - } -} - -// inputFromArgsOrEditor returns either the provided input from the args slice -// (joined with spaces), or if len(args) is 0, opens an editor and returns -// whatever input was provided there. placeholder is a string which populates -// the editor and gets stripped from the final output. -func inputFromArgsOrEditor(args []string, placeholder string, existingMessage string) (message string) { - var err error - if len(args) == 0 { - message, err = InputFromEditor(placeholder, "message.*.md", existingMessage) - if err != nil { - Fatal("Failed to get input: %v\n", err) - } - } else { - message = strings.Trim(strings.Join(args, " "), " \t\n") - } - return -} - -var rootCmd = &cobra.Command{ - Use: "lmcli [flags]", - Long: `lmcli - Large Language Model CLI`, - SilenceErrors: true, - SilenceUsage: true, - Run: func(cmd *cobra.Command, args []string) { - cmd.Usage() - }, -} - -var listCmd = &cobra.Command{ - Use: "list", - Aliases: []string{"ls"}, - Short: "List conversations", - Long: `List conversations in order of recent activity`, - Run: func(cmd *cobra.Command, args []string) { - conversations, err := store.Conversations() - if err != nil { - Fatal("Could not fetch conversations.\n") - return - } - - type Category struct { - name string - cutoff time.Duration - } - - type ConversationLine struct { - timeSinceReply time.Duration - formatted string - } - - now := time.Now() - - midnight := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()) - monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location()) - dayOfWeek := int(now.Weekday()) - categories := []Category{ - {"today", now.Sub(midnight)}, - {"yesterday", now.Sub(midnight.AddDate(0, 0, -1))}, - {"this week", now.Sub(midnight.AddDate(0, 0, -dayOfWeek))}, - {"last week", now.Sub(midnight.AddDate(0, 0, -(dayOfWeek + 7)))}, - {"this month", now.Sub(monthStart)}, - {"last month", now.Sub(monthStart.AddDate(0, -1, 0))}, - {"2 months ago", now.Sub(monthStart.AddDate(0, -2, 0))}, - {"3 months ago", now.Sub(monthStart.AddDate(0, -3, 0))}, - {"4 months ago", now.Sub(monthStart.AddDate(0, -4, 0))}, - {"5 months ago", now.Sub(monthStart.AddDate(0, -5, 0))}, - {"older", now.Sub(time.Time{})}, - } - categorized := map[string][]ConversationLine{} - - all, _ := cmd.Flags().GetBool("all") - count, _ := cmd.Flags().GetInt("count") - - for _, conversation := range conversations { - lastMessage, err := store.LastMessage(&conversation) - if lastMessage == nil || err != nil { - continue - } - - messageAge := now.Sub(lastMessage.CreatedAt) - - var category string - for _, c := range categories { - if messageAge < c.cutoff { - category = c.name - break - } - } - - formatted := fmt.Sprintf( - "%s - %s - %s", - conversation.ShortName.String, - humanTimeElapsedSince(messageAge), - conversation.Title, - ) - - categorized[category] = append( - categorized[category], - ConversationLine{messageAge, formatted}, - ) - } - - var conversationsPrinted int - outer: - for _, category := range categories { - conversationLines, ok := categorized[category.name] - if !ok { - continue - } - - slices.SortFunc(conversationLines, func(a, b ConversationLine) int { - return int(a.timeSinceReply - b.timeSinceReply) - }) - - fmt.Printf("%s:\n", category.name) - for _, conv := range conversationLines { - if conversationsPrinted >= count && !all { - fmt.Printf("%d remaining message(s), use --all to view.\n", len(conversations)-conversationsPrinted) - break outer - } - - fmt.Printf(" %s\n", conv.formatted) - conversationsPrinted++ - } - } - }, -} - -var rmCmd = &cobra.Command{ - Use: "rm ...", - Short: "Remove conversations", - Long: `Remove conversations by their short names.`, - Args: func(cmd *cobra.Command, args []string) error { - argCount := 1 - if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { - return err - } - return nil - }, - Run: func(cmd *cobra.Command, args []string) { - var toRemove []*Conversation - for _, shortName := range args { - conversation := lookupConversation(shortName) - toRemove = append(toRemove, conversation) - } - var errors []error - for _, c := range toRemove { - err := store.DeleteConversation(c) - if err != nil { - errors = append(errors, fmt.Errorf("Could not remove conversation %s: %v", c.ShortName.String, err)) - } - } - for _, err := range errors { - fmt.Fprintln(os.Stderr, err.Error()) - } - if len(errors) > 0 { - os.Exit(1) - } - }, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - compMode := cobra.ShellCompDirectiveNoFileComp - var completions []string - outer: - for _, completion := range store.ConversationShortNameCompletions(toComplete) { - parts := strings.Split(completion, "\t") - for _, arg := range args { - if parts[0] == arg { - continue outer - } - } - completions = append(completions, completion) - } - return completions, compMode - }, -} - -var cloneCmd = &cobra.Command{ - Use: "clone ", - Short: "Clone conversations", - Long: `Clones the provided conversation.`, - Args: func(cmd *cobra.Command, args []string) error { - argCount := 1 - if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { - return err - } - return nil - }, - RunE: func(cmd *cobra.Command, args []string) error { - shortName := args[0] - toClone, err := lookupConversationE(shortName) - if err != nil { - return err - } - - messagesToCopy, err := store.Messages(toClone) - if err != nil { - return fmt.Errorf("Could not retrieve messages for conversation: %s", toClone.ShortName.String) - } - - clone := &Conversation{ - Title: toClone.Title + " - Clone", - } - if err := store.SaveConversation(clone); err != nil { - return fmt.Errorf("Cloud not create clone: %s", err) - } - - var errors []error - messageCnt := 0 - for _, message := range messagesToCopy { - newMessage := message - newMessage.ConversationID = clone.ID - newMessage.ID = 0 - if err := store.SaveMessage(&newMessage); err != nil { - errors = append(errors, err) - } else { - messageCnt++ - } - } - - if len(errors) > 0 { - return fmt.Errorf("Messages failed to be cloned: %v", errors) - } - - fmt.Printf("Cloned %d messages to: %s\n", messageCnt, clone.Title) - return nil - }, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - compMode := cobra.ShellCompDirectiveNoFileComp - if len(args) != 0 { - return nil, compMode - } - return store.ConversationShortNameCompletions(toComplete), compMode - }, -} - -var viewCmd = &cobra.Command{ - Use: "view ", - Short: "View messages in a conversation", - Long: `Finds a conversation by its short name and displays its contents.`, - Args: func(cmd *cobra.Command, args []string) error { - argCount := 1 - if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { - return err - } - return nil - }, - Run: func(cmd *cobra.Command, args []string) { - shortName := args[0] - conversation := lookupConversation(shortName) - - messages, err := store.Messages(conversation) - if err != nil { - Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title) - } - - RenderConversation(messages, false) - }, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - compMode := cobra.ShellCompDirectiveNoFileComp - if len(args) != 0 { - return nil, compMode - } - return store.ConversationShortNameCompletions(toComplete), compMode - }, -} - -var renameCmd = &cobra.Command{ - Use: "rename [title]", - Short: "Rename a conversation", - Long: `Renames a conversation, either with the provided title or by generating a new name.`, - Args: func(cmd *cobra.Command, args []string) error { - argCount := 1 - if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { - return err - } - return nil - }, - Run: func(cmd *cobra.Command, args []string) { - shortName := args[0] - conversation := lookupConversation(shortName) - var err error - - generate, _ := cmd.Flags().GetBool("generate") - var title string - if generate { - title, err = conversation.GenerateTitle() - if err != nil { - Fatal("Could not generate conversation title: %v\n", err) - } - } else { - if len(args) < 2 { - Fatal("Conversation title not provided.\n") - } - title = strings.Join(args[1:], " ") - } - - conversation.Title = title - err = store.SaveConversation(conversation) - if err != nil { - Warn("Could not save conversation with new title: %v\n", err) - } - }, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - compMode := cobra.ShellCompDirectiveNoFileComp - if len(args) != 0 { - return nil, compMode - } - return store.ConversationShortNameCompletions(toComplete), compMode - }, -} - -var replyCmd = &cobra.Command{ - Use: "reply [message]", - Short: "Reply to a conversation", - Long: `Sends a reply to conversation and writes the response to stdout.`, - Args: func(cmd *cobra.Command, args []string) error { - argCount := 1 - if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { - return err - } - return nil - }, - Run: func(cmd *cobra.Command, args []string) { - shortName := args[0] - conversation := lookupConversation(shortName) - - reply := inputFromArgsOrEditor(args[1:], "# How would you like to reply?\n", "") - if reply == "" { - Fatal("No reply was provided.\n") - } - - handleConversationReply(conversation, true, Message{ - ConversationID: conversation.ID, - Role: MessageRoleUser, - OriginalContent: reply, - }) - }, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - compMode := cobra.ShellCompDirectiveNoFileComp - if len(args) != 0 { - return nil, compMode - } - return store.ConversationShortNameCompletions(toComplete), compMode - }, -} - -var newCmd = &cobra.Command{ - Use: "new [message]", - Short: "Start a new conversation", - Long: `Start a new conversation with the Large Language Model.`, - Run: func(cmd *cobra.Command, args []string) { - messageContents := inputFromArgsOrEditor(args, "# What would you like to say?\n", "") - if messageContents == "" { - Fatal("No message was provided.\n") - } - - conversation := &Conversation{} - err := store.SaveConversation(conversation) - if err != nil { - Fatal("Could not save new conversation: %v\n", err) - } - - messages := []Message{ - { - ConversationID: conversation.ID, - Role: MessageRoleSystem, - OriginalContent: getSystemPrompt(), - }, - { - ConversationID: conversation.ID, - Role: MessageRoleUser, - OriginalContent: messageContents, - }, - } - - handleConversationReply(conversation, true, messages...) - - title, err := conversation.GenerateTitle() - if err != nil { - Warn("Could not generate title for conversation: %v\n", err) - } - - conversation.Title = title - - err = store.SaveConversation(conversation) - if err != nil { - Warn("Could not save conversation after generating title: %v\n", err) - } - }, -} - -var promptCmd = &cobra.Command{ - Use: "prompt [message]", - Short: "Do a one-shot prompt", - Long: `Prompt the Large Language Model and get a response.`, - Run: func(cmd *cobra.Command, args []string) { - message := inputFromArgsOrEditor(args, "# What would you like to say?\n", "") - if message == "" { - Fatal("No message was provided.\n") - } - - messages := []Message{ - { - Role: MessageRoleSystem, - OriginalContent: getSystemPrompt(), - }, - { - Role: MessageRoleUser, - OriginalContent: message, - }, - } - - _, err := fetchAndShowCompletion(messages) - if err != nil { - Fatal("Error fetching LLM response: %v\n", err) - } - }, -} - -var retryCmd = &cobra.Command{ - Use: "retry ", - Short: "Retry the last user reply in a conversation", - Long: `Re-prompt the conversation up to the last user response. Can be used to regenerate the last assistant reply, or simply generate one if an error occurred.`, - Args: func(cmd *cobra.Command, args []string) error { - argCount := 1 - if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { - return err - } - return nil - }, - Run: func(cmd *cobra.Command, args []string) { - shortName := args[0] - conversation := lookupConversation(shortName) - - messages, err := store.Messages(conversation) - if err != nil { - Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title) - } - - // walk backwards through the conversation and delete messages, break - // when we find the latest user response - for i := len(messages) - 1; i >= 0; i-- { - if messages[i].Role == MessageRoleUser { - break - } - - err = store.DeleteMessage(&messages[i]) - if err != nil { - Warn("Could not delete previous reply: %v\n", err) - } - } - - handleConversationReply(conversation, true) - }, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - compMode := cobra.ShellCompDirectiveNoFileComp - if len(args) != 0 { - return nil, compMode - } - return store.ConversationShortNameCompletions(toComplete), compMode - }, -} - -var continueCmd = &cobra.Command{ - Use: "continue ", - Short: "Continue a conversation from the last message", - Long: `Re-prompt the conversation with all existing prompts. Useful if a reply was cut short.`, - Args: func(cmd *cobra.Command, args []string) error { - argCount := 1 - if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { - return err - } - return nil - }, - Run: func(cmd *cobra.Command, args []string) { - shortName := args[0] - conversation := lookupConversation(shortName) - handleConversationReply(conversation, true) - }, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - compMode := cobra.ShellCompDirectiveNoFileComp - if len(args) != 0 { - return nil, compMode - } - return store.ConversationShortNameCompletions(toComplete), compMode - }, -} - -var editCmd = &cobra.Command{ - Use: "edit ", - Short: "Edit the last user reply in a conversation", - Args: func(cmd *cobra.Command, args []string) error { - argCount := 1 - if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { - return err - } - return nil - }, - Run: func(cmd *cobra.Command, args []string) { - shortName := args[0] - conversation := lookupConversation(shortName) - - messages, err := store.Messages(conversation) - if err != nil { - Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title) - } - - offset, _ := cmd.Flags().GetInt("offset") - if offset < 0 { - offset = -offset - } - - if offset > len(messages) - 1 { - Fatal("Offset %d is before the start of the conversation\n", offset) - } - - desiredIdx := len(messages) - 1 - offset - - // walk backwards through the conversation deleting messages until and - // including the last user message - toRemove := []Message{} - var toEdit *Message - for i := len(messages) - 1; i >= 0; i-- { - if i == desiredIdx { - toEdit = &messages[i] - } - toRemove = append(toRemove, messages[i]) - messages = messages[:i] - if toEdit != nil { - break - } - } - - existingContents := toEdit.OriginalContent - - newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", existingContents) - switch newContents { - case existingContents: - Fatal("No edits were made.\n") - case "": - Fatal("No message was provided.\n") - } - - for _, message := range toRemove { - err = store.DeleteMessage(&message) - if err != nil { - Warn("Could not delete message: %v\n", err) - } - } - - handleConversationReply(conversation, true, Message{ - ConversationID: conversation.ID, - Role: MessageRoleUser, - OriginalContent: newContents, - }) - }, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - compMode := cobra.ShellCompDirectiveNoFileComp - if len(args) != 0 { - return nil, compMode - } - return store.ConversationShortNameCompletions(toComplete), compMode - }, -} diff --git a/pkg/cli/conversation.go b/pkg/cli/conversation.go deleted file mode 100644 index 3ea3078..0000000 --- a/pkg/cli/conversation.go +++ /dev/null @@ -1,67 +0,0 @@ -package cli - -import ( - "fmt" - "strings" -) - -type MessageRole string - -const ( - MessageRoleUser MessageRole = "user" - MessageRoleAssistant MessageRole = "assistant" - MessageRoleSystem MessageRole = "system" -) - -// FriendlyRole returns a human friendly signifier for the message's role. -func (m *Message) FriendlyRole() string { - var friendlyRole string - switch m.Role { - case MessageRoleUser: - friendlyRole = "You" - case MessageRoleSystem: - friendlyRole = "System" - case MessageRoleAssistant: - friendlyRole = "Assistant" - default: - friendlyRole = string(m.Role) - } - return friendlyRole -} - -func (c *Conversation) GenerateTitle() (string, error) { - messages, err := store.Messages(c) - if err != nil { - return "", err - } - - const header = "Generate a concise 4-5 word title for the conversation below." - prompt := fmt.Sprintf("%s\n\n---\n\n%s", header, formatForExternalPrompting(messages, false)) - - generateRequest := []Message{ - { - Role: MessageRoleUser, - OriginalContent: prompt, - }, - } - - model := "gpt-3.5-turbo" // use cheap model to generate title - response, err := CreateChatCompletion(model, generateRequest, 25, nil) - if err != nil { - return "", err - } - - return response, nil -} - -func formatForExternalPrompting(messages []Message, system bool) string { - sb := strings.Builder{} - for _, message := range messages { - if message.Role == MessageRoleSystem && !system { - continue - } - sb.WriteString(fmt.Sprintf("<%s>\n", message.FriendlyRole())) - sb.WriteString(fmt.Sprintf("\"\"\"\n%s\n\"\"\"\n\n", message.OriginalContent)) - } - return sb.String() -} diff --git a/pkg/cli/functions.go b/pkg/cli/functions.go deleted file mode 100644 index ea8a9c3..0000000 --- a/pkg/cli/functions.go +++ /dev/null @@ -1,582 +0,0 @@ -package cli - -import ( - "database/sql" - "encoding/json" - "fmt" - "os" - "path/filepath" - "strings" - - openai "github.com/sashabaranov/go-openai" -) - -type FunctionResult struct { - Message string `json:"message"` - Result any `json:"result,omitempty"` -} - -type FunctionParameter struct { - Type string `json:"type"` // "string", "integer", "boolean" - Description string `json:"description"` - Enum []string `json:"enum,omitempty"` -} - -type FunctionParameters struct { - Type string `json:"type"` // "object" - Properties map[string]FunctionParameter `json:"properties"` - Required []string `json:"required,omitempty"` // required function parameter names -} - -type AvailableTool struct { - openai.Tool - // The tool's implementation. Returns a string, as tool call results - // are treated as normal messages with string contents. - Impl func(arguments map[string]interface{}) (string, error) -} - -const ( - READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory). - -Results are returned as JSON in the following format: -{ - "message": "success", // if successful, or a different message indicating failure - // result may be an empty array if there are no files in the directory - "result": [ - {"name": "a_file", "type": "file", "size": 123}, - {"name": "a_directory/", "type": "dir", "size": 11}, - ... // more files or directories - ] -} - -For files, size represents the size (in bytes) of the file. -For directories, size represents the number of entries in that directory.` - - READ_FILE_DESCRIPTION = `Read the contents of a text file relative to the current working directory. - -Each line of the file is prefixed with its line number and a tabs (\t) to make -it make it easier to see which lines to change for other modifications. - -Example result: -{ - "message": "success", // if successful, or a different message indicating failure - "result": "1\tthe contents\n2\tof the file\n" -}` - - WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory. - -Note: only use this tool when you've been explicitly asked to create or write to a file. - -When using this function, you do not need to share the content you intend to write with the user first. - -Example result: -{ - "message": "success", // if successful, or a different message indicating failure -}` - - FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path. - -Make sure your inserts match the flow and indentation of surrounding content.` - - FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path. - -Useful for re-writing snippets/blocks of code or entire functions. - -Be cautious with your edits. When replacing, ensure the replacement content matches the flow and indentation of surrounding content.` -) - -var AvailableTools = map[string]AvailableTool{ - "read_dir": { - Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{ - Name: "read_dir", - Description: READ_DIR_DESCRIPTION, - Parameters: FunctionParameters{ - Type: "object", - Properties: map[string]FunctionParameter{ - "relative_dir": { - Type: "string", - Description: "If set, read the contents of a directory relative to the current one.", - }, - }, - }, - }}, - Impl: func(args map[string]interface{}) (string, error) { - var relativeDir string - tmp, ok := args["relative_dir"] - if ok { - relativeDir, ok = tmp.(string) - if !ok { - return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp) - } - } - return ReadDir(relativeDir), nil - }, - }, - "read_file": { - Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{ - Name: "read_file", - Description: READ_FILE_DESCRIPTION, - Parameters: FunctionParameters{ - Type: "object", - Properties: map[string]FunctionParameter{ - "path": { - Type: "string", - Description: "Path to a file within the current working directory to read.", - }, - }, - Required: []string{"path"}, - }, - }}, - Impl: func(args map[string]interface{}) (string, error) { - tmp, ok := args["path"] - if !ok { - return "", fmt.Errorf("Path parameter to read_file was not included.") - } - path, ok := tmp.(string) - if !ok { - return "", fmt.Errorf("Invalid path in function arguments: %v", tmp) - } - return ReadFile(path), nil - }, - }, - "write_file": { - Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{ - Name: "write_file", - Description: WRITE_FILE_DESCRIPTION, - Parameters: FunctionParameters{ - Type: "object", - Properties: map[string]FunctionParameter{ - "path": { - Type: "string", - Description: "Path to a file within the current working directory to write to.", - }, - "content": { - Type: "string", - Description: "The content to write to the file. Overwrites any existing content!", - }, - }, - Required: []string{"path", "content"}, - }, - }}, - Impl: func(args map[string]interface{}) (string, error) { - tmp, ok := args["path"] - if !ok { - return "", fmt.Errorf("Path parameter to write_file was not included.") - } - path, ok := tmp.(string) - if !ok { - return "", fmt.Errorf("Invalid path in function arguments: %v", tmp) - } - tmp, ok = args["content"] - if !ok { - return "", fmt.Errorf("Content parameter to write_file was not included.") - } - content, ok := tmp.(string) - if !ok { - return "", fmt.Errorf("Invalid content in function arguments: %v", tmp) - } - return WriteFile(path, content), nil - }, - }, - "file_insert_lines": { - Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{ - Name: "file_insert_lines", - Description: FILE_INSERT_LINES_DESCRIPTION, - Parameters: FunctionParameters{ - Type: "object", - Properties: map[string]FunctionParameter{ - "path": { - Type: "string", - Description: "Path of the file to be modified, relative to the current working directory.", - }, - "position": { - Type: "integer", - Description: `Which line to insert content *before*.`, - }, - "content": { - Type: "string", - Description: `The content to insert.`, - }, - }, - Required: []string{"path", "position", "content"}, - }, - }}, - Impl: func(args map[string]interface{}) (string, error) { - tmp, ok := args["path"] - if !ok { - return "", fmt.Errorf("path parameter to write_file was not included.") - } - path, ok := tmp.(string) - if !ok { - return "", fmt.Errorf("Invalid path in function arguments: %v", tmp) - } - var position int - tmp, ok = args["position"] - if ok { - tmp, ok := tmp.(float64) - if !ok { - return "", fmt.Errorf("Invalid position in function arguments: %v", tmp) - } - position = int(tmp) - } - var content string - tmp, ok = args["content"] - if ok { - content, ok = tmp.(string) - if !ok { - return "", fmt.Errorf("Invalid content in function arguments: %v", tmp) - } - } - return FileInsertLines(path, position, content), nil - }, - }, - "file_replace_lines": { - Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{ - Name: "file_replace_lines", - Description: FILE_REPLACE_LINES_DESCRIPTION, - Parameters: FunctionParameters{ - Type: "object", - Properties: map[string]FunctionParameter{ - "path": { - Type: "string", - Description: "Path of the file to be modified, relative to the current working directory.", - }, - "start_line": { - Type: "integer", - Description: `Line number which specifies the start of the replacement range (inclusive).`, - }, - "end_line": { - Type: "integer", - Description: `Line number which specifies the end of the replacement range (inclusive). If unset, range extends to end of file.`, - }, - "content": { - Type: "string", - Description: `Content to replace specified range. Omit to remove the specified range.`, - }, - }, - Required: []string{"path", "start_line"}, - }, - }}, - Impl: func(args map[string]interface{}) (string, error) { - tmp, ok := args["path"] - if !ok { - return "", fmt.Errorf("path parameter to write_file was not included.") - } - path, ok := tmp.(string) - if !ok { - return "", fmt.Errorf("Invalid path in function arguments: %v", tmp) - } - var start_line int - tmp, ok = args["start_line"] - if ok { - tmp, ok := tmp.(float64) - if !ok { - return "", fmt.Errorf("Invalid start_line in function arguments: %v", tmp) - } - start_line = int(tmp) - } - var end_line int - tmp, ok = args["end_line"] - if ok { - tmp, ok := tmp.(float64) - if !ok { - return "", fmt.Errorf("Invalid end_line in function arguments: %v", tmp) - } - end_line = int(tmp) - } - var content string - tmp, ok = args["content"] - if ok { - content, ok = tmp.(string) - if !ok { - return "", fmt.Errorf("Invalid content in function arguments: %v", tmp) - } - } - - return FileReplaceLines(path, start_line, end_line, content), nil - }, - }, -} - -func resultToJson(result FunctionResult) string { - if result.Message == "" { - // When message not supplied, assume success - result.Message = "success" - } - - jsonBytes, err := json.Marshal(result) - if err != nil { - fmt.Printf("Could not marshal FunctionResult to JSON: %v\n", err) - } - return string(jsonBytes) -} - -// ExecuteToolCalls handles the execution of all tool_calls provided, and -// returns their results formatted as []Message(s) with role: 'tool' and. -func ExecuteToolCalls(toolCalls []openai.ToolCall) ([]Message, error) { - var toolResults []Message - for _, toolCall := range toolCalls { - if toolCall.Type != "function" { - // unsupported tool type - continue - } - - tool, ok := AvailableTools[toolCall.Function.Name] - if !ok { - return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Function.Name) - } - - var functionArgs map[string]interface{} - err := json.Unmarshal([]byte(toolCall.Function.Arguments), &functionArgs) - if err != nil { - return nil, fmt.Errorf("Could not unmarshal tool arguments. Malformed JSON? Error: %v", err) - } - - // TODO: ability to silence this - fmt.Fprintf(os.Stderr, "INFO: Executing tool '%s' with args %s\n", toolCall.Function.Name, toolCall.Function.Arguments) - - // Execute the tool - toolResult, err := tool.Impl(functionArgs) - if err != nil { - // This can happen if the model missed or supplied invalid tool args - return nil, fmt.Errorf("Tool '%s' error: %v\n", toolCall.Function.Name, err) - } - - toolResults = append(toolResults, Message{ - Role: "tool", - OriginalContent: toolResult, - ToolCallID: sql.NullString{String: toolCall.ID, Valid: true}, - // name is not required since the introduction of ToolCallID - // hypothesis: by setting it, we inform the model of what a - // function's purpose was if future requests omit the function - // definition - }) - } - return toolResults, nil -} - -// isPathContained attempts to verify whether `path` is the same as or -// contained within `directory`. It is overly cautious, returning false even if -// `path` IS contained within `directory`, but the two paths use different -// casing, and we happen to be on a case-insensitive filesystem. -// This is ultimately to attempt to stop an LLM from going outside of where I -// tell it to. Additional layers of security should be considered.. run in a -// VM/container. -func isPathContained(directory string, path string) (bool, error) { - // Clean and resolve symlinks for both paths - path, err := filepath.Abs(path) - if err != nil { - return false, err - } - - // check if path exists - _, err = os.Stat(path) - if err != nil { - if !os.IsNotExist(err) { - return false, fmt.Errorf("Could not stat path: %v", err) - } - } else { - path, err = filepath.EvalSymlinks(path) - if err != nil { - return false, err - } - } - - directory, err = filepath.Abs(directory) - if err != nil { - return false, err - } - directory, err = filepath.EvalSymlinks(directory) - if err != nil { - return false, err - } - - // Case insensitive checks - if !strings.EqualFold(path, directory) && - !strings.HasPrefix(strings.ToLower(path), strings.ToLower(directory)+string(os.PathSeparator)) { - return false, nil - } - - return true, nil -} - -func isPathWithinCWD(path string) (bool, *FunctionResult) { - cwd, err := os.Getwd() - if err != nil { - return false, &FunctionResult{Message: "Failed to determine current working directory"} - } - if ok, err := isPathContained(cwd, path); !ok { - if err != nil { - return false, &FunctionResult{Message: fmt.Sprintf("Could not determine whether path '%s' is within the current working directory: %s", path, err.Error())} - } - return false, &FunctionResult{Message: fmt.Sprintf("Path '%s' is not within the current working directory", path)} - } - return true, nil -} - -func ReadDir(path string) string { - // TODO(?): implement whitelist - list of directories which model is allowed to work in - if path == "" { - path = "." - } - ok, res := isPathWithinCWD(path) - if !ok { - return resultToJson(*res) - } - - files, err := os.ReadDir(path) - if err != nil { - return resultToJson(FunctionResult{ - Message: err.Error(), - }) - } - - var dirContents []map[string]interface{} - for _, f := range files { - info, _ := f.Info() - - name := f.Name() - if strings.HasPrefix(name, ".") { - // skip hidden files - continue - } - - entryType := "file" - size := info.Size() - - if info.IsDir() { - name += "/" - entryType = "dir" - subdirfiles, _ := os.ReadDir(filepath.Join(".", path, info.Name())) - size = int64(len(subdirfiles)) - } - - dirContents = append(dirContents, map[string]interface{}{ - "name": name, - "type": entryType, - "size": size, - }) - } - - return resultToJson(FunctionResult{Result: dirContents}) -} - -func ReadFile(path string) string { - ok, res := isPathWithinCWD(path) - if !ok { - return resultToJson(*res) - } - data, err := os.ReadFile(path) - if err != nil { - return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}) - } - - lines := strings.Split(string(data), "\n") - content := strings.Builder{} - for i, line := range lines { - content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line)) - } - - return resultToJson(FunctionResult{ - Result: content.String(), - }) -} - -func WriteFile(path string, content string) string { - ok, res := isPathWithinCWD(path) - if !ok { - return resultToJson(*res) - } - err := os.WriteFile(path, []byte(content), 0644) - if err != nil { - return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}) - } - return resultToJson(FunctionResult{}) -} - -func FileInsertLines(path string, position int, content string) string { - ok, res := isPathWithinCWD(path) - if !ok { - return resultToJson(*res) - } - - // Read the existing file's content - data, err := os.ReadFile(path) - if err != nil { - if !os.IsNotExist(err) { - return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}) - } - _, err = os.Create(path) - if err != nil { - return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}) - } - data = []byte{} - } - - if position < 1 { - return resultToJson(FunctionResult{Message: "start_line cannot be less than 1"}) - } - - lines := strings.Split(string(data), "\n") - contentLines := strings.Split(strings.Trim(content, "\n"), "\n") - - before := lines[:position-1] - after := lines[position-1:] - lines = append(before, append(contentLines, after...)...) - - newContent := strings.Join(lines, "\n") - - // Join the lines and write back to the file - err = os.WriteFile(path, []byte(newContent), 0644) - if err != nil { - return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}) - } - - return resultToJson(FunctionResult{Result: newContent}) -} - -func FileReplaceLines(path string, startLine int, endLine int, content string) string { - ok, res := isPathWithinCWD(path) - if !ok { - return resultToJson(*res) - } - - // Read the existing file's content - data, err := os.ReadFile(path) - if err != nil { - if !os.IsNotExist(err) { - return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}) - } - _, err = os.Create(path) - if err != nil { - return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}) - } - data = []byte{} - } - - if startLine < 1 { - return resultToJson(FunctionResult{Message: "start_line cannot be less than 1"}) - } - - lines := strings.Split(string(data), "\n") - contentLines := strings.Split(strings.Trim(content, "\n"), "\n") - - if endLine == 0 || endLine > len(lines) { - endLine = len(lines) - } - - before := lines[:startLine-1] - after := lines[endLine:] - - lines = append(before, append(contentLines, after...)...) - newContent := strings.Join(lines, "\n") - - // Join the lines and write back to the file - err = os.WriteFile(path, []byte(newContent), 0644) - if err != nil { - return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}) - } - - return resultToJson(FunctionResult{Result: newContent}) - -} diff --git a/pkg/cli/openai.go b/pkg/cli/openai.go deleted file mode 100644 index ce111b8..0000000 --- a/pkg/cli/openai.go +++ /dev/null @@ -1,187 +0,0 @@ -package cli - -import ( - "context" - "database/sql" - "encoding/json" - "errors" - "fmt" - "io" - "strings" - - openai "github.com/sashabaranov/go-openai" -) - -func CreateChatCompletionRequest(model string, messages []Message, maxTokens int) openai.ChatCompletionRequest { - chatCompletionMessages := []openai.ChatCompletionMessage{} - for _, m := range messages { - message := openai.ChatCompletionMessage{ - Role: string(m.Role), - Content: m.OriginalContent, - } - if m.ToolCallID.Valid { - message.ToolCallID = m.ToolCallID.String - } - if m.ToolCalls.Valid { - // unmarshal directly into chatMessage.ToolCalls - err := json.Unmarshal([]byte(m.ToolCalls.String), &message.ToolCalls) - if err != nil { - // TODO: handle, this shouldn't really happen since - // we only save the successfully marshal'd data to database - fmt.Printf("Error unmarshalling the tool_calls JSON: %v\n", err) - } - } - chatCompletionMessages = append(chatCompletionMessages, message) - } - - request := openai.ChatCompletionRequest{ - Model: model, - Messages: chatCompletionMessages, - MaxTokens: maxTokens, - N: 1, // limit responses to 1 "choice". we use choices[0] to reference it - } - - var tools []openai.Tool - for _, t := range config.OpenAI.EnabledTools { - tool, ok := AvailableTools[t] - if ok { - tools = append(tools, tool.Tool) - } - } - - if len(tools) > 0 { - request.Tools = tools - request.ToolChoice = "auto" - } - - return request -} - -// CreateChatCompletion submits a Chat Completion API request and returns the -// response. CreateChatCompletion will recursively call itself in the case of -// tool calls, until a response is received with the final user-facing output. -func CreateChatCompletion(model string, messages []Message, maxTokens int, replies *[]Message) (string, error) { - client := openai.NewClient(*config.OpenAI.APIKey) - req := CreateChatCompletionRequest(model, messages, maxTokens) - resp, err := client.CreateChatCompletion(context.Background(), req) - if err != nil { - return "", err - } - - choice := resp.Choices[0] - - if len(choice.Message.ToolCalls) > 0 { - // Append the assistant's reply with its request for tool calls - toolCallJson, _ := json.Marshal(choice.Message.ToolCalls) - assistantReply := Message{ - Role: MessageRoleAssistant, - ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true}, - } - - toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls) - if err != nil { - return "", err - } - - if replies != nil { - *replies = append(append(*replies, assistantReply), toolReplies...) - } - - messages = append(append(messages, assistantReply), toolReplies...) - // Recurse into CreateChatCompletion with the tool call replies added - // to the original messages - return CreateChatCompletion(model, messages, maxTokens, replies) - } - - if replies != nil { - *replies = append(*replies, Message{ - Role: MessageRoleAssistant, - OriginalContent: choice.Message.Content, - }) - } - - // Return the user-facing message. - return choice.Message.Content, nil -} - -// CreateChatCompletionStream submits a streaming Chat Completion API request -// and both returns and streams the response to the provided output channel. -// May return a partial response if an error occurs mid-stream. -func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string, replies *[]Message) (string, error) { - client := openai.NewClient(*config.OpenAI.APIKey) - req := CreateChatCompletionRequest(model, messages, maxTokens) - - stream, err := client.CreateChatCompletionStream(context.Background(), req) - if err != nil { - return "", err - } - defer stream.Close() - - content := strings.Builder{} - toolCalls := []openai.ToolCall{} - - // Iterate stream segments - for { - response, e := stream.Recv() - if errors.Is(e, io.EOF) { - break - } - - if e != nil { - err = e - break - } - - delta := response.Choices[0].Delta - if len(delta.ToolCalls) > 0 { - // Construct streamed tool_call arguments - for _, tc := range delta.ToolCalls { - if tc.Index == nil { - return "", fmt.Errorf("Unexpected nil index for streamed tool call.") - } - if len(toolCalls) <= *tc.Index { - toolCalls = append(toolCalls, tc) - } else { - toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments - } - } - } else { - output <- delta.Content - content.WriteString(delta.Content) - } - } - - if len(toolCalls) > 0 { - // Append the assistant's reply with its request for tool calls - toolCallJson, _ := json.Marshal(toolCalls) - - assistantReply := Message{ - Role: MessageRoleAssistant, - OriginalContent: content.String(), - ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true}, - } - - toolReplies, err := ExecuteToolCalls(toolCalls) - if err != nil { - return "", err - } - - if replies != nil { - *replies = append(append(*replies, assistantReply), toolReplies...) - } - - // Recurse into CreateChatCompletionStream with the tool call replies - // added to the original messages - messages = append(append(messages, assistantReply), toolReplies...) - return CreateChatCompletionStream(model, messages, maxTokens, output, replies) - } - - if replies != nil { - *replies = append(*replies, Message{ - Role: MessageRoleAssistant, - OriginalContent: content.String(), - }) - } - - return content.String(), err -} diff --git a/pkg/cli/store.go b/pkg/cli/store.go deleted file mode 100644 index 5ffbacd..0000000 --- a/pkg/cli/store.go +++ /dev/null @@ -1,141 +0,0 @@ -package cli - -import ( - "errors" - "database/sql" - "fmt" - "os" - "path/filepath" - "strings" - "time" - - sqids "github.com/sqids/sqids-go" - "gorm.io/driver/sqlite" - "gorm.io/gorm" -) - -type Store struct { - db *gorm.DB - sqids *sqids.Sqids -} - -type Message struct { - ID uint `gorm:"primaryKey"` - ConversationID uint `gorm:"foreignKey:ConversationID"` - Conversation Conversation - OriginalContent string - Role MessageRole // one of: 'system', 'user', 'assistant', 'tool' - CreatedAt time.Time - ToolCallID sql.NullString - ToolCalls sql.NullString // a json-encoded array of tool calls from the model -} - -type Conversation struct { - ID uint `gorm:"primaryKey"` - ShortName sql.NullString - Title string -} - -func dataDir() string { - var dataDir string - - xdgDataHome := os.Getenv("XDG_DATA_HOME") - if xdgDataHome != "" { - dataDir = filepath.Join(xdgDataHome, "lmcli") - } else { - userHomeDir, _ := os.UserHomeDir() - dataDir = filepath.Join(userHomeDir, ".local/share/lmcli") - } - - os.MkdirAll(dataDir, 0755) - return dataDir -} - -func NewStore() (*Store, error) { - databaseFile := filepath.Join(dataDir(), "conversations.db") - db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{}) - if err != nil { - return nil, fmt.Errorf("Error establishing connection to store: %v", err) - } - - models := []any{ - &Conversation{}, - &Message{}, - } - - for _, x := range models { - err := db.AutoMigrate(x) - if err != nil { - return nil, fmt.Errorf("Could not perform database migrations: %v", err) - } - } - - _sqids, _ := sqids.New(sqids.Options{MinLength: 4}) - return &Store{db, _sqids}, nil -} - -func (s *Store) SaveConversation(conversation *Conversation) error { - err := s.db.Save(&conversation).Error - if err != nil { - return err - } - - if !conversation.ShortName.Valid { - shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)}) - conversation.ShortName = sql.NullString{String: shortName, Valid: true} - err = s.db.Save(&conversation).Error - } - - return err -} - -func (s *Store) DeleteConversation(conversation *Conversation) error { - s.db.Where("conversation_id = ?", conversation.ID).Delete(&Message{}) - return s.db.Delete(&conversation).Error -} - -func (s *Store) SaveMessage(message *Message) error { - return s.db.Create(message).Error -} - -func (s *Store) DeleteMessage(message *Message) error { - return s.db.Delete(&message).Error -} - -func (s *Store) Conversations() ([]Conversation, error) { - var conversations []Conversation - err := s.db.Find(&conversations).Error - return conversations, err -} - -func (s *Store) ConversationShortNameCompletions(shortName string) []string { - var completions []string - conversations, _ := s.Conversations() // ignore error for completions - for _, conversation := range conversations { - if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) { - completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title)) - } - } - return completions -} - -func (s *Store) ConversationByShortName(shortName string) (*Conversation, error) { - if shortName == "" { - return nil, errors.New("shortName is empty") - } - var conversation Conversation - err := s.db.Where("short_name = ?", shortName).Find(&conversation).Error - return &conversation, err -} - -func (s *Store) Messages(conversation *Conversation) ([]Message, error) { - var messages []Message - err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error - return messages, err -} - -func (s *Store) LastMessage(conversation *Conversation) (*Message, error) { - var message Message - err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error - return &message, err -} diff --git a/pkg/cli/tty.go b/pkg/cli/tty.go deleted file mode 100644 index ab275cc..0000000 --- a/pkg/cli/tty.go +++ /dev/null @@ -1,113 +0,0 @@ -package cli - -import ( - "fmt" - "os" - "time" - - "github.com/alecthomas/chroma/v2/quick" - "github.com/gookit/color" -) - -// ShowWaitAnimation "draws" an animated ellipses to stdout until something is -// received on the signal channel. An empty string sent to the channel to -// noftify the caller that the animation has completed (carriage returned). -func ShowWaitAnimation(signal chan any) { - animationStep := 0 - for { - select { - case _ = <-signal: - fmt.Print("\r") - signal <- "" - return - default: - modSix := animationStep % 6 - if modSix == 3 || modSix == 0 { - fmt.Print("\r") - } - if modSix < 3 { - fmt.Print(".") - } else { - fmt.Print(" ") - } - animationStep++ - time.Sleep(250 * time.Millisecond) - } - } -} - -// ShowDelayedContent displays a waiting animation to stdout while waiting -// for content to be received on the provided channel. As soon as any (possibly -// chunked) content is received on the channel, the waiting animation is -// replaced by the content. -// Blocks until the channel is closed. -func ShowDelayedContent(content <-chan string) { - waitSignal := make(chan any) - go ShowWaitAnimation(waitSignal) - - firstChunk := true - for chunk := range content { - if firstChunk { - // notify wait animation that we've received data - waitSignal <- "" - // wait for signal that wait animation has completed - <-waitSignal - firstChunk = false - } - fmt.Print(chunk) - } -} - -// RenderConversation renders the given messages to TTY, with optional space -// for a subsequent message. spaceForResponse controls how many '\n' characters -// are printed immediately after the final message (1 if false, 2 if true) -func RenderConversation(messages []Message, spaceForResponse bool) { - l := len(messages) - for i, message := range messages { - message.RenderTTY() - if i < l-1 || spaceForResponse { - // print an additional space before the next message - fmt.Println() - } - } -} - -// HighlightMarkdown applies syntax highlighting to the provided markdown text -// and writes it to stdout. -func HighlightMarkdown(markdownText string) error { - return quick.Highlight(os.Stdout, markdownText, "md", *config.Chroma.Formatter, *config.Chroma.Style) -} - -func (m *Message) RenderTTY() { - var messageAge string - if m.CreatedAt.IsZero() { - messageAge = "now" - } else { - now := time.Now() - messageAge = humanTimeElapsedSince(now.Sub(m.CreatedAt)) - } - - var roleStyle color.Style - switch m.Role { - case MessageRoleSystem: - roleStyle = color.Style{color.HiRed} - case MessageRoleUser: - roleStyle = color.Style{color.HiGreen} - case MessageRoleAssistant: - roleStyle = color.Style{color.HiBlue} - default: - roleStyle = color.Style{color.FgWhite} - } - roleStyle.Add(color.Bold) - - headerColor := color.FgYellow - separator := headerColor.Sprint("===") - timestamp := headerColor.Sprint(messageAge) - role := roleStyle.Sprint(m.FriendlyRole()) - - fmt.Printf("%s %s - %s %s\n\n", separator, role, timestamp, separator) - if m.OriginalContent != "" { - HighlightMarkdown(m.OriginalContent) - fmt.Println() - } -} diff --git a/pkg/cmd/clone.go b/pkg/cmd/clone.go new file mode 100644 index 0000000..a32024f --- /dev/null +++ b/pkg/cmd/clone.go @@ -0,0 +1,72 @@ +package cmd + +import ( + "fmt" + + cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" + "git.mlow.ca/mlow/lmcli/pkg/lmcli" + "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" + "github.com/spf13/cobra" +) + +func CloneCmd(ctx *lmcli.Context) *cobra.Command { + cmd := &cobra.Command{ + Use: "clone ", + Short: "Clone conversations", + Long: `Clones the provided conversation.`, + Args: func(cmd *cobra.Command, args []string) error { + argCount := 1 + if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { + return err + } + return nil + }, + RunE: func(cmd *cobra.Command, args []string) error { + shortName := args[0] + toClone, err := cmdutil.LookupConversationE(ctx, shortName) + if err != nil { + return err + } + + messagesToCopy, err := ctx.Store.Messages(toClone) + if err != nil { + return fmt.Errorf("Could not retrieve messages for conversation: %s", toClone.ShortName.String) + } + + clone := &model.Conversation{ + Title: toClone.Title + " - Clone", + } + if err := ctx.Store.SaveConversation(clone); err != nil { + return fmt.Errorf("Cloud not create clone: %s", err) + } + + var errors []error + messageCnt := 0 + for _, message := range messagesToCopy { + newMessage := message + newMessage.ConversationID = clone.ID + newMessage.ID = 0 + if err := ctx.Store.SaveMessage(&newMessage); err != nil { + errors = append(errors, err) + } else { + messageCnt++ + } + } + + if len(errors) > 0 { + return fmt.Errorf("Messages failed to be cloned: %v", errors) + } + + fmt.Printf("Cloned %d messages to: %s\n", messageCnt, clone.Title) + return nil + }, + ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + compMode := cobra.ShellCompDirectiveNoFileComp + if len(args) != 0 { + return nil, compMode + } + return ctx.Store.ConversationShortNameCompletions(toComplete), compMode + }, + } + return cmd +} diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go new file mode 100644 index 0000000..e3963d5 --- /dev/null +++ b/pkg/cmd/cmd.go @@ -0,0 +1,93 @@ +package cmd + +import ( + "strings" + + "git.mlow.ca/mlow/lmcli/pkg/lmcli" + "git.mlow.ca/mlow/lmcli/pkg/util" + "github.com/spf13/cobra" +) + +var ( + systemPromptFile string +) + +func RootCmd(ctx *lmcli.Context) *cobra.Command { + var root = &cobra.Command{ + Use: "lmcli [flags]", + Long: `lmcli - Large Language Model CLI`, + SilenceErrors: true, + SilenceUsage: true, + Run: func(cmd *cobra.Command, args []string) { + cmd.Usage() + }, + } + + continueCmd := ContinueCmd(ctx) + cloneCmd := CloneCmd(ctx) + editCmd := EditCmd(ctx) + listCmd := ListCmd(ctx) + newCmd := NewCmd(ctx) + promptCmd := PromptCmd(ctx) + renameCmd := RenameCmd(ctx) + replyCmd := ReplyCmd(ctx) + retryCmd := RetryCmd(ctx) + rmCmd := RemoveCmd(ctx) + viewCmd := ViewCmd(ctx) + + inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd, retryCmd, continueCmd, editCmd} + for _, cmd := range inputCmds { + cmd.Flags().StringVar(ctx.Config.Defaults.Model, "model", *ctx.Config.Defaults.Model, "Which model to use") + cmd.Flags().IntVar(ctx.Config.Defaults.MaxTokens, "length", *ctx.Config.Defaults.MaxTokens, "Maximum response tokens") + cmd.Flags().StringVar(ctx.Config.Defaults.SystemPrompt, "system-prompt", *ctx.Config.Defaults.SystemPrompt, "System prompt") + cmd.Flags().StringVar(&systemPromptFile, "system-prompt-file", "", "A path to a file containing the system prompt") + cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file") + } + + renameCmd.Flags().Bool("generate", false, "Generate a conversation title") + + root.AddCommand( + cloneCmd, + continueCmd, + editCmd, + listCmd, + newCmd, + promptCmd, + renameCmd, + replyCmd, + retryCmd, + rmCmd, + viewCmd, + ) + + return root +} + +func getSystemPrompt(ctx *lmcli.Context) string { + if systemPromptFile != "" { + content, err := util.ReadFileContents(systemPromptFile) + if err != nil { + lmcli.Fatal("Could not read file contents at %s: %v\n", systemPromptFile, err) + } + return content + } + return *ctx.Config.Defaults.SystemPrompt +} + +// inputFromArgsOrEditor returns either the provided input from the args slice +// (joined with spaces), or if len(args) is 0, opens an editor and returns +// whatever input was provided there. placeholder is a string which populates +// the editor and gets stripped from the final output. +func inputFromArgsOrEditor(args []string, placeholder string, existingMessage string) (message string) { + var err error + if len(args) == 0 { + message, err = util.InputFromEditor(placeholder, "message.*.md", existingMessage) + if err != nil { + lmcli.Fatal("Failed to get input: %v\n", err) + } + } else { + message = strings.Join(args, " ") + } + message = strings.Trim(message, " \t\n") + return +} diff --git a/pkg/cmd/continue.go b/pkg/cmd/continue.go new file mode 100644 index 0000000..c164fea --- /dev/null +++ b/pkg/cmd/continue.go @@ -0,0 +1,72 @@ +package cmd + +import ( + "fmt" + "strings" + + cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" + "git.mlow.ca/mlow/lmcli/pkg/lmcli" + "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" + "github.com/spf13/cobra" +) + +func ContinueCmd(ctx *lmcli.Context) *cobra.Command { + cmd := &cobra.Command{ + Use: "continue ", + Short: "Continue a conversation from the last message", + Long: `Re-prompt the conversation with all existing prompts. Useful if a reply was cut short.`, + Args: func(cmd *cobra.Command, args []string) error { + argCount := 1 + if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { + return err + } + return nil + }, + RunE: func(cmd *cobra.Command, args []string) error { + shortName := args[0] + conversation := cmdutil.LookupConversation(ctx, shortName) + + messages, err := ctx.Store.Messages(conversation) + if err != nil { + return fmt.Errorf("could not retrieve conversation messages: %v", err) + } + + if len(messages) < 2 { + return fmt.Errorf("conversation expected to have at least 2 messages") + } + + lastMessage := &messages[len(messages)-1] + if lastMessage.Role != model.MessageRoleAssistant { + return fmt.Errorf("the last message in the conversation is not an assistant message") + } + + // Output the contents of the last message so far + fmt.Print(lastMessage.Content) + + // Submit the LLM request, allowing it to continue the last message + continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages) + if err != nil { + return fmt.Errorf("error fetching LLM response: %v", err) + } + + // Append the new response to the original message + lastMessage.Content += strings.TrimRight(continuedOutput[0].Content, "\n\t ") + + // Update the original message + err = ctx.Store.UpdateMessage(lastMessage) + if err != nil { + return fmt.Errorf("could not update the last message: %v", err) + } + + return nil + }, + ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + compMode := cobra.ShellCompDirectiveNoFileComp + if len(args) != 0 { + return nil, compMode + } + return ctx.Store.ConversationShortNameCompletions(toComplete), compMode + }, + } + return cmd +} diff --git a/pkg/cmd/edit.go b/pkg/cmd/edit.go new file mode 100644 index 0000000..b8bcbaf --- /dev/null +++ b/pkg/cmd/edit.go @@ -0,0 +1,100 @@ +package cmd + +import ( + "fmt" + + cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" + "git.mlow.ca/mlow/lmcli/pkg/lmcli" + "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" + "github.com/spf13/cobra" +) + +func EditCmd(ctx *lmcli.Context) *cobra.Command { + cmd := &cobra.Command{ + Use: "edit ", + Short: "Edit the last user reply in a conversation", + Args: func(cmd *cobra.Command, args []string) error { + argCount := 1 + if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { + return err + } + return nil + }, + RunE: func(cmd *cobra.Command, args []string) error { + shortName := args[0] + conversation := cmdutil.LookupConversation(ctx, shortName) + + messages, err := ctx.Store.Messages(conversation) + if err != nil { + return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title) + } + + offset, _ := cmd.Flags().GetInt("offset") + if offset < 0 { + offset = -offset + } + + if offset > len(messages)-1 { + return fmt.Errorf("Offset %d is before the start of the conversation.", offset) + } + + desiredIdx := len(messages) - 1 - offset + + // walk backwards through the conversation deleting messages until and + // including the last user message + toRemove := []model.Message{} + var toEdit *model.Message + for i := len(messages) - 1; i >= 0; i-- { + if i == desiredIdx { + toEdit = &messages[i] + } + toRemove = append(toRemove, messages[i]) + messages = messages[:i] + if toEdit != nil { + break + } + } + + newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", toEdit.Content) + switch newContents { + case toEdit.Content: + return fmt.Errorf("No edits were made.") + case "": + return fmt.Errorf("No message was provided.") + } + + role, _ := cmd.Flags().GetString("role") + if role == "" { + role = string(toEdit.Role) + } else if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) { + return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.") + } + + for _, message := range toRemove { + err = ctx.Store.DeleteMessage(&message) + if err != nil { + lmcli.Warn("Could not delete message: %v\n", err) + } + } + + cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{ + ConversationID: conversation.ID, + Role: model.MessageRole(role), + Content: newContents, + }) + return nil + }, + ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + compMode := cobra.ShellCompDirectiveNoFileComp + if len(args) != 0 { + return nil, compMode + } + return ctx.Store.ConversationShortNameCompletions(toComplete), compMode + }, + } + + cmd.Flags().Int("offset", 1, "Offset from the last message to edit") + cmd.Flags().StringP("role", "r", "", "Role of the edited message (user or assistant)") + + return cmd +} diff --git a/pkg/cmd/list.go b/pkg/cmd/list.go new file mode 100644 index 0000000..ef15b84 --- /dev/null +++ b/pkg/cmd/list.go @@ -0,0 +1,122 @@ +package cmd + +import ( + "fmt" + "slices" + "time" + + "git.mlow.ca/mlow/lmcli/pkg/lmcli" + "git.mlow.ca/mlow/lmcli/pkg/util" + "github.com/spf13/cobra" +) + +const ( + LS_COUNT int = 5 +) + +func ListCmd(ctx *lmcli.Context) *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Aliases: []string{"ls"}, + Short: "List conversations", + Long: `List conversations in order of recent activity`, + RunE: func(cmd *cobra.Command, args []string) error { + conversations, err := ctx.Store.Conversations() + if err != nil { + return fmt.Errorf("Could not fetch conversations: %v", err) + } + + type Category struct { + name string + cutoff time.Duration + } + + type ConversationLine struct { + timeSinceReply time.Duration + formatted string + } + + now := time.Now() + + midnight := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()) + monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location()) + dayOfWeek := int(now.Weekday()) + categories := []Category{ + {"today", now.Sub(midnight)}, + {"yesterday", now.Sub(midnight.AddDate(0, 0, -1))}, + {"this week", now.Sub(midnight.AddDate(0, 0, -dayOfWeek))}, + {"last week", now.Sub(midnight.AddDate(0, 0, -(dayOfWeek + 7)))}, + {"this month", now.Sub(monthStart)}, + {"last month", now.Sub(monthStart.AddDate(0, -1, 0))}, + {"2 months ago", now.Sub(monthStart.AddDate(0, -2, 0))}, + {"3 months ago", now.Sub(monthStart.AddDate(0, -3, 0))}, + {"4 months ago", now.Sub(monthStart.AddDate(0, -4, 0))}, + {"5 months ago", now.Sub(monthStart.AddDate(0, -5, 0))}, + {"older", now.Sub(time.Time{})}, + } + categorized := map[string][]ConversationLine{} + + all, _ := cmd.Flags().GetBool("all") + + for _, conversation := range conversations { + lastMessage, err := ctx.Store.LastMessage(&conversation) + if lastMessage == nil || err != nil { + continue + } + + messageAge := now.Sub(lastMessage.CreatedAt) + + var category string + for _, c := range categories { + if messageAge < c.cutoff { + category = c.name + break + } + } + + formatted := fmt.Sprintf( + "%s - %s - %s", + conversation.ShortName.String, + util.HumanTimeElapsedSince(messageAge), + conversation.Title, + ) + + categorized[category] = append( + categorized[category], + ConversationLine{messageAge, formatted}, + ) + } + + count, _ := cmd.Flags().GetInt("count") + var conversationsPrinted int + outer: + for _, category := range categories { + conversationLines, ok := categorized[category.name] + if !ok { + continue + } + + slices.SortFunc(conversationLines, func(a, b ConversationLine) int { + return int(a.timeSinceReply - b.timeSinceReply) + }) + + fmt.Printf("%s:\n", category.name) + for _, conv := range conversationLines { + if conversationsPrinted >= count && !all { + fmt.Printf("%d remaining message(s), use --all to view.\n", len(conversations)-conversationsPrinted) + break outer + } + + fmt.Printf(" %s\n", conv.formatted) + conversationsPrinted++ + } + } + return nil + }, + } + + cmd.Flags().Bool("all", false, "Show all conversations") + cmd.Flags().Int("count", LS_COUNT, "How many conversations to show") + + return cmd +} diff --git a/pkg/cmd/new.go b/pkg/cmd/new.go new file mode 100644 index 0000000..6875681 --- /dev/null +++ b/pkg/cmd/new.go @@ -0,0 +1,60 @@ +package cmd + +import ( + "fmt" + + cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" + "git.mlow.ca/mlow/lmcli/pkg/lmcli" + "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" + "github.com/spf13/cobra" +) + +func NewCmd(ctx *lmcli.Context) *cobra.Command { + cmd := &cobra.Command{ + Use: "new [message]", + Short: "Start a new conversation", + Long: `Start a new conversation with the Large Language Model.`, + RunE: func(cmd *cobra.Command, args []string) error { + messageContents := inputFromArgsOrEditor(args, "# What would you like to say?\n", "") + if messageContents == "" { + return fmt.Errorf("No message was provided.") + } + + conversation := &model.Conversation{} + err := ctx.Store.SaveConversation(conversation) + if err != nil { + return fmt.Errorf("Could not save new conversation: %v", err) + } + + messages := []model.Message{ + { + ConversationID: conversation.ID, + Role: model.MessageRoleSystem, + Content: getSystemPrompt(ctx), + }, + { + ConversationID: conversation.ID, + Role: model.MessageRoleUser, + Content: messageContents, + }, + } + + cmdutil.HandleConversationReply(ctx, conversation, true, messages...) + + title, err := cmdutil.GenerateTitle(ctx, conversation) + if err != nil { + lmcli.Warn("Could not generate title for conversation: %v\n", err) + } + + conversation.Title = title + + err = ctx.Store.SaveConversation(conversation) + if err != nil { + lmcli.Warn("Could not save conversation after generating title: %v\n", err) + } + return nil + }, + } + + return cmd +} diff --git a/pkg/cmd/prompt.go b/pkg/cmd/prompt.go new file mode 100644 index 0000000..4362c29 --- /dev/null +++ b/pkg/cmd/prompt.go @@ -0,0 +1,42 @@ +package cmd + +import ( + "fmt" + + cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" + "git.mlow.ca/mlow/lmcli/pkg/lmcli" + "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" + "github.com/spf13/cobra" +) + +func PromptCmd(ctx *lmcli.Context) *cobra.Command { + cmd := &cobra.Command{ + Use: "prompt [message]", + Short: "Do a one-shot prompt", + Long: `Prompt the Large Language Model and get a response.`, + RunE: func(cmd *cobra.Command, args []string) error { + message := inputFromArgsOrEditor(args, "# What would you like to say?\n", "") + if message == "" { + return fmt.Errorf("No message was provided.") + } + + messages := []model.Message{ + { + Role: model.MessageRoleSystem, + Content: getSystemPrompt(ctx), + }, + { + Role: model.MessageRoleUser, + Content: message, + }, + } + + _, err := cmdutil.FetchAndShowCompletion(ctx, messages) + if err != nil { + return fmt.Errorf("Error fetching LLM response: %v", err) + } + return nil + }, + } + return cmd +} diff --git a/pkg/cmd/remove.go b/pkg/cmd/remove.go new file mode 100644 index 0000000..12cec51 --- /dev/null +++ b/pkg/cmd/remove.go @@ -0,0 +1,60 @@ +package cmd + +import ( + "fmt" + "strings" + + cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" + "git.mlow.ca/mlow/lmcli/pkg/lmcli" + "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" + "github.com/spf13/cobra" +) + +func RemoveCmd(ctx *lmcli.Context) *cobra.Command { + cmd := &cobra.Command{ + Use: "rm ...", + Short: "Remove conversations", + Long: `Remove conversations by their short names.`, + Args: func(cmd *cobra.Command, args []string) error { + argCount := 1 + if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { + return err + } + return nil + }, + RunE: func(cmd *cobra.Command, args []string) error { + var toRemove []*model.Conversation + for _, shortName := range args { + conversation := cmdutil.LookupConversation(ctx, shortName) + toRemove = append(toRemove, conversation) + } + var errors []error + for _, c := range toRemove { + err := ctx.Store.DeleteConversation(c) + if err != nil { + errors = append(errors, fmt.Errorf("Could not remove conversation %s: %v", c.ShortName.String, err)) + } + } + if len(errors) > 0 { + return fmt.Errorf("Could not remove some conversations: %v", errors) + } + return nil + }, + ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + compMode := cobra.ShellCompDirectiveNoFileComp + var completions []string + outer: + for _, completion := range ctx.Store.ConversationShortNameCompletions(toComplete) { + parts := strings.Split(completion, "\t") + for _, arg := range args { + if parts[0] == arg { + continue outer + } + } + completions = append(completions, completion) + } + return completions, compMode + }, + } + return cmd +} diff --git a/pkg/cmd/rename.go b/pkg/cmd/rename.go new file mode 100644 index 0000000..c876f54 --- /dev/null +++ b/pkg/cmd/rename.go @@ -0,0 +1,60 @@ +package cmd + +import ( + "fmt" + "strings" + + "git.mlow.ca/mlow/lmcli/pkg/lmcli" + "github.com/spf13/cobra" + cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" +) + +func RenameCmd(ctx *lmcli.Context) *cobra.Command { + cmd := &cobra.Command{ + Use: "rename [title]", + Short: "Rename a conversation", + Long: `Renames a conversation, either with the provided title or by generating a new name.`, + Args: func(cmd *cobra.Command, args []string) error { + argCount := 1 + if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { + return err + } + return nil + }, + RunE: func(cmd *cobra.Command, args []string) error { + shortName := args[0] + conversation := cmdutil.LookupConversation(ctx, shortName) + var err error + + generate, _ := cmd.Flags().GetBool("generate") + var title string + if generate { + title, err = cmdutil.GenerateTitle(ctx, conversation) + if err != nil { + return fmt.Errorf("Could not generate conversation title: %v", err) + } + } else { + if len(args) < 2 { + return fmt.Errorf("Conversation title not provided.") + } + title = strings.Join(args[1:], " ") + } + + conversation.Title = title + err = ctx.Store.SaveConversation(conversation) + if err != nil { + lmcli.Warn("Could not save conversation with new title: %v\n", err) + } + return nil + }, + ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + compMode := cobra.ShellCompDirectiveNoFileComp + if len(args) != 0 { + return nil, compMode + } + return ctx.Store.ConversationShortNameCompletions(toComplete), compMode + }, + } + + return cmd +} diff --git a/pkg/cmd/reply.go b/pkg/cmd/reply.go new file mode 100644 index 0000000..d923aaa --- /dev/null +++ b/pkg/cmd/reply.go @@ -0,0 +1,49 @@ +package cmd + +import ( + "fmt" + + cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" + "git.mlow.ca/mlow/lmcli/pkg/lmcli" + "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" + "github.com/spf13/cobra" +) + +func ReplyCmd(ctx *lmcli.Context) *cobra.Command { + cmd := &cobra.Command{ + Use: "reply [message]", + Short: "Reply to a conversation", + Long: `Sends a reply to conversation and writes the response to stdout.`, + Args: func(cmd *cobra.Command, args []string) error { + argCount := 1 + if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { + return err + } + return nil + }, + RunE: func(cmd *cobra.Command, args []string) error { + shortName := args[0] + conversation := cmdutil.LookupConversation(ctx, shortName) + + reply := inputFromArgsOrEditor(args[1:], "# How would you like to reply?\n", "") + if reply == "" { + return fmt.Errorf("No reply was provided.") + } + + cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{ + ConversationID: conversation.ID, + Role: model.MessageRoleUser, + Content: reply, + }) + return nil + }, + ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + compMode := cobra.ShellCompDirectiveNoFileComp + if len(args) != 0 { + return nil, compMode + } + return ctx.Store.ConversationShortNameCompletions(toComplete), compMode + }, + } + return cmd +} diff --git a/pkg/cmd/retry.go b/pkg/cmd/retry.go new file mode 100644 index 0000000..9604830 --- /dev/null +++ b/pkg/cmd/retry.go @@ -0,0 +1,58 @@ +package cmd + +import ( + "fmt" + + cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" + "git.mlow.ca/mlow/lmcli/pkg/lmcli" + "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" + "github.com/spf13/cobra" +) + +func RetryCmd(ctx *lmcli.Context) *cobra.Command { + cmd := &cobra.Command{ + Use: "retry ", + Short: "Retry the last user reply in a conversation", + Long: `Re-prompt the conversation up to the last user response. Can be used to regenerate the last assistant reply, or simply generate one if an error occurred.`, + Args: func(cmd *cobra.Command, args []string) error { + argCount := 1 + if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { + return err + } + return nil + }, + RunE: func(cmd *cobra.Command, args []string) error { + shortName := args[0] + conversation := cmdutil.LookupConversation(ctx, shortName) + + messages, err := ctx.Store.Messages(conversation) + if err != nil { + return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title) + } + + // walk backwards through the conversation and delete messages, break + // when we find the latest user response + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role == model.MessageRoleUser { + break + } + + err = ctx.Store.DeleteMessage(&messages[i]) + if err != nil { + lmcli.Warn("Could not delete previous reply: %v\n", err) + } + } + + cmdutil.HandleConversationReply(ctx, conversation, true) + return nil + }, + ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + compMode := cobra.ShellCompDirectiveNoFileComp + if len(args) != 0 { + return nil, compMode + } + return ctx.Store.ConversationShortNameCompletions(toComplete), compMode + }, + } + return cmd +} diff --git a/pkg/cmd/util/util.go b/pkg/cmd/util/util.go new file mode 100644 index 0000000..d7d3de4 --- /dev/null +++ b/pkg/cmd/util/util.go @@ -0,0 +1,284 @@ +package util + +import ( + "fmt" + "io" + "os" + "strings" + "time" + + "git.mlow.ca/mlow/lmcli/pkg/lmcli" + "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" + "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" + "git.mlow.ca/mlow/lmcli/pkg/util" + "github.com/alecthomas/chroma/v2/quick" + "github.com/charmbracelet/lipgloss" +) + +// fetchAndShowCompletion prompts the LLM with the given messages and streams +// the response to stdout. Returns all model reply messages. +func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]model.Message, error) { + content := make(chan string) // receives the reponse from LLM + defer close(content) + + // render all content received over the channel + go ShowDelayedContent(content) + + completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model) + if err != nil { + return nil, err + } + + var toolBag []model.Tool + for _, toolName := range *ctx.Config.Tools.EnabledTools { + tool, ok := tools.AvailableTools[toolName] + if ok { + toolBag = append(toolBag, tool) + } + } + + requestParams := model.RequestParameters{ + Model: *ctx.Config.Defaults.Model, + MaxTokens: *ctx.Config.Defaults.MaxTokens, + Temperature: *ctx.Config.Defaults.Temperature, + ToolBag: toolBag, + } + + var apiReplies []model.Message + response, err := completionProvider.CreateChatCompletionStream( + requestParams, messages, &apiReplies, content, + ) + if response != "" { + // there was some content, so break to a new line after it + fmt.Println() + + if err != nil { + lmcli.Warn("Received partial response. Error: %v\n", err) + err = nil + } + } + + return apiReplies, err +} + +// lookupConversation either returns the conversation found by the +// short name or exits the program +func LookupConversation(ctx *lmcli.Context, shortName string) *model.Conversation { + c, err := ctx.Store.ConversationByShortName(shortName) + if err != nil { + lmcli.Fatal("Could not lookup conversation: %v\n", err) + } + if c.ID == 0 { + lmcli.Fatal("Conversation not found with short name: %s\n", shortName) + } + return c +} + +func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversation, error) { + c, err := ctx.Store.ConversationByShortName(shortName) + if err != nil { + return nil, fmt.Errorf("Could not lookup conversation: %v", err) + } + if c.ID == 0 { + return nil, fmt.Errorf("Conversation not found with short name: %s", shortName) + } + return c, nil +} + +// handleConversationReply handles sending messages to an existing +// conversation, optionally persisting both the sent replies and responses. +func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) { + existing, err := ctx.Store.Messages(c) + if err != nil { + lmcli.Fatal("Could not retrieve messages for conversation: %s\n", c.Title) + } + + if persist { + for _, message := range toSend { + err = ctx.Store.SaveMessage(&message) + if err != nil { + lmcli.Warn("Could not save %s message: %v\n", message.Role, err) + } + } + } + + allMessages := append(existing, toSend...) + + RenderConversation(ctx, allMessages, true) + + // render a message header with no contents + RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant})) + + replies, err := FetchAndShowCompletion(ctx, allMessages) + if err != nil { + lmcli.Fatal("Error fetching LLM response: %v\n", err) + } + + if persist { + for _, reply := range replies { + reply.ConversationID = c.ID + + err = ctx.Store.SaveMessage(&reply) + if err != nil { + lmcli.Warn("Could not save reply: %v\n", err) + } + } + } +} + +func FormatForExternalPrompt(messages []model.Message, system bool) string { + sb := strings.Builder{} + for _, message := range messages { + if message.Role != model.MessageRoleUser && (message.Role != model.MessageRoleSystem || !system) { + continue + } + sb.WriteString(fmt.Sprintf("<%s>\n", message.Role.FriendlyRole())) + sb.WriteString(fmt.Sprintf("\"\"\"\n%s\n\"\"\"\n\n", message.Content)) + } + return sb.String() +} + +func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) { + messages, err := ctx.Store.Messages(c) + if err != nil { + return "", err + } + + const header = "Generate a concise 4-5 word title for the conversation below." + prompt := fmt.Sprintf("%s\n\n---\n\n%s", header, FormatForExternalPrompt(messages, false)) + + generateRequest := []model.Message{ + { + Role: model.MessageRoleUser, + Content: prompt, + }, + } + + completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Conversations.TitleGenerationModel) + if err != nil { + return "", err + } + + requestParams := model.RequestParameters{ + Model: *ctx.Config.Conversations.TitleGenerationModel, + MaxTokens: 25, + } + + response, err := completionProvider.CreateChatCompletion(requestParams, generateRequest, nil) + if err != nil { + return "", err + } + + return response, nil +} + +// ShowWaitAnimation prints an animated ellipses to stdout until something is +// received on the signal channel. An empty string sent to the channel to +// noftify the caller that the animation has completed (carriage returned). +func ShowWaitAnimation(signal chan any) { + // Save the current cursor position + fmt.Print("\033[s") + + animationStep := 0 + for { + select { + case _ = <-signal: + // Relmcli the cursor position + fmt.Print("\033[u") + signal <- "" + return + default: + // Move the cursor to the saved position + modSix := animationStep % 6 + if modSix == 3 || modSix == 0 { + fmt.Print("\033[u") + } + if modSix < 3 { + fmt.Print(".") + } else { + fmt.Print(" ") + } + animationStep++ + time.Sleep(250 * time.Millisecond) + } + } +} + +// ShowDelayedContent displays a waiting animation to stdout while waiting +// for content to be received on the provided channel. As soon as any (possibly +// chunked) content is received on the channel, the waiting animation is +// replaced by the content. +// Blocks until the channel is closed. +func ShowDelayedContent(content <-chan string) { + waitSignal := make(chan any) + go ShowWaitAnimation(waitSignal) + + firstChunk := true + for chunk := range content { + if firstChunk { + // notify wait animation that we've received data + waitSignal <- "" + // wait for signal that wait animation has completed + <-waitSignal + firstChunk = false + } + fmt.Print(chunk) + } +} + +// RenderConversation renders the given messages to TTY, with optional space +// for a subsequent message. spaceForResponse controls how many '\n' characters +// are printed immediately after the final message (1 if false, 2 if true) +func RenderConversation(ctx *lmcli.Context, messages []model.Message, spaceForResponse bool) { + l := len(messages) + for i, message := range messages { + RenderMessage(ctx, &message) + if i < l-1 || spaceForResponse { + // print an additional space before the next message + fmt.Println() + } + } +} + +// HighlightMarkdown applies syntax highlighting to the provided markdown text +// and writes it to stdout. +func HighlightMarkdown(w io.Writer, markdownText string, formatter string, style string) error { + return quick.Highlight(w, markdownText, "md", formatter, style) +} + +func RenderMessage(ctx *lmcli.Context, m *model.Message) { + var messageAge string + if m.CreatedAt.IsZero() { + messageAge = "now" + } else { + now := time.Now() + messageAge = util.HumanTimeElapsedSince(now.Sub(m.CreatedAt)) + } + + headerStyle := lipgloss.NewStyle().Bold(true) + + switch m.Role { + case model.MessageRoleSystem: + headerStyle = headerStyle.Foreground(lipgloss.Color("9")) // bright red + case model.MessageRoleUser: + headerStyle = headerStyle.Foreground(lipgloss.Color("10")) // bright green + case model.MessageRoleAssistant: + headerStyle = headerStyle.Foreground(lipgloss.Color("12")) // bright blue + } + + role := headerStyle.Render(m.Role.FriendlyRole()) + + separatorStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("3")) + separator := separatorStyle.Render("===") + timestamp := separatorStyle.Render(messageAge) + + fmt.Printf("%s %s - %s %s\n\n", separator, role, timestamp, separator) + if m.Content != "" { + HighlightMarkdown( + os.Stdout, m.Content, + *ctx.Config.Chroma.Formatter, + *ctx.Config.Chroma.Style, + ) + fmt.Println() + } +} diff --git a/pkg/cmd/view.go b/pkg/cmd/view.go new file mode 100644 index 0000000..5cffc55 --- /dev/null +++ b/pkg/cmd/view.go @@ -0,0 +1,45 @@ +package cmd + +import ( + "fmt" + + cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" + "git.mlow.ca/mlow/lmcli/pkg/lmcli" + "github.com/spf13/cobra" +) + +func ViewCmd(ctx *lmcli.Context) *cobra.Command { + cmd := &cobra.Command{ + Use: "view ", + Short: "View messages in a conversation", + Long: `Finds a conversation by its short name and displays its contents.`, + Args: func(cmd *cobra.Command, args []string) error { + argCount := 1 + if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { + return err + } + return nil + }, + RunE: func(cmd *cobra.Command, args []string) error { + shortName := args[0] + conversation := cmdutil.LookupConversation(ctx, shortName) + + messages, err := ctx.Store.Messages(conversation) + if err != nil { + return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title) + } + + cmdutil.RenderConversation(ctx, messages, false) + return nil + }, + ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + compMode := cobra.ShellCompDirectiveNoFileComp + if len(args) != 0 { + return nil, compMode + } + return ctx.Store.ConversationShortNameCompletions(toComplete), compMode + }, + } + + return cmd +} diff --git a/pkg/cli/config.go b/pkg/lmcli/config.go similarity index 51% rename from pkg/cli/config.go rename to pkg/lmcli/config.go index 16f0d75..b5eb5ab 100644 --- a/pkg/cli/config.go +++ b/pkg/lmcli/config.go @@ -1,46 +1,41 @@ -package cli +package lmcli import ( "fmt" "os" - "path/filepath" + "git.mlow.ca/mlow/lmcli/pkg/util" "github.com/go-yaml/yaml" ) type Config struct { - ModelDefaults *struct { - SystemPrompt *string `yaml:"systemPrompt" default:"You are a helpful assistant."` - } `yaml:"modelDefaults"` + Defaults *struct { + SystemPrompt *string `yaml:"systemPrompt" default:"You are a helpful assistant."` + MaxTokens *int `yaml:"maxTokens" default:"256"` + Temperature *float32 `yaml:"temperature" default:"0.7"` + Model *string `yaml:"model" default:"gpt-4"` + } `yaml:"defaults"` + Conversations *struct { + TitleGenerationModel *string `yaml:"titleGenerationModel" default:"gpt-3.5-turbo"` + } `yaml:"conversations"` + Tools *struct { + EnabledTools *[]string `yaml:"enabledTools"` + } `yaml:"tools"` OpenAI *struct { - APIKey *string `yaml:"apiKey" default:"your_key_here"` - DefaultModel *string `yaml:"defaultModel" default:"gpt-4"` - DefaultMaxLength *int `yaml:"defaultMaxLength" default:"256"` - EnabledTools []string `yaml:"enabledTools"` + APIKey *string `yaml:"apiKey" default:"your_key_here"` + Models *[]string `yaml:"models"` } `yaml:"openai"` + Anthropic *struct { + APIKey *string `yaml:"apiKey" default:"your_key_here"` + Models *[]string `yaml:"models"` + } `yaml:"anthropic"` Chroma *struct { Style *string `yaml:"style" default:"onedark"` Formatter *string `yaml:"formatter" default:"terminal16m"` } `yaml:"chroma"` } -func configDir() string { - var configDir string - - xdgConfigHome := os.Getenv("XDG_CONFIG_HOME") - if xdgConfigHome != "" { - configDir = filepath.Join(xdgConfigHome, "lmcli") - } else { - userHomeDir, _ := os.UserHomeDir() - configDir = filepath.Join(userHomeDir, ".config/lmcli") - } - - os.MkdirAll(configDir, 0755) - return configDir -} - -func NewConfig() (*Config, error) { - configFile := filepath.Join(configDir(), "config.yaml") +func NewConfig(configFile string) (*Config, error) { shouldWriteDefaults := false c := &Config{} @@ -54,11 +49,11 @@ func NewConfig() (*Config, error) { yaml.Unmarshal(configBytes, c) } - shouldWriteDefaults = SetStructDefaults(c) + shouldWriteDefaults = util.SetStructDefaults(c) if !configExists || shouldWriteDefaults { if configExists { - fmt.Printf("Saving new defaults to configuration, backing up existing configuration to %s\n", configFile + ".bak") - os.Rename(configFile, configFile + ".bak") + fmt.Printf("Saving new defaults to configuration, backing up existing configuration to %s\n", configFile+".bak") + os.Rename(configFile, configFile+".bak") } fmt.Printf("Writing configuration file to %s\n", configFile) file, err := os.Create(configFile) diff --git a/pkg/lmcli/lmcli.go b/pkg/lmcli/lmcli.go new file mode 100644 index 0000000..ab6ce16 --- /dev/null +++ b/pkg/lmcli/lmcli.go @@ -0,0 +1,97 @@ +package lmcli + +import ( + "fmt" + "os" + "path/filepath" + + "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider" + "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic" + "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +type Context struct { + Config Config + Store ConversationStore +} + +func NewContext() (*Context, error) { + configFile := filepath.Join(configDir(), "config.yaml") + config, err := NewConfig(configFile) + if err != nil { + Fatal("%v\n", err) + } + + databaseFile := filepath.Join(dataDir(), "conversations.db") + db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{}) + if err != nil { + return nil, fmt.Errorf("Error establishing connection to store: %v", err) + } + s, err := NewSQLStore(db) + if err != nil { + Fatal("%v\n", err) + } + + return &Context{*config, s}, nil +} + +func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) { + for _, m := range *c.Config.Anthropic.Models { + if m == model { + anthropic := &anthropic.AnthropicClient{ + APIKey: *c.Config.Anthropic.APIKey, + } + return anthropic, nil + } + } + for _, m := range *c.Config.OpenAI.Models { + if m == model { + openai := &openai.OpenAIClient{ + APIKey: *c.Config.OpenAI.APIKey, + } + return openai, nil + } + } + return nil, fmt.Errorf("unknown model: %s", model) +} + +func configDir() string { + var configDir string + + xdgConfigHome := os.Getenv("XDG_CONFIG_HOME") + if xdgConfigHome != "" { + configDir = filepath.Join(xdgConfigHome, "lmcli") + } else { + userHomeDir, _ := os.UserHomeDir() + configDir = filepath.Join(userHomeDir, ".config/lmcli") + } + + os.MkdirAll(configDir, 0755) + return configDir +} + +func dataDir() string { + var dataDir string + + xdgDataHome := os.Getenv("XDG_DATA_HOME") + if xdgDataHome != "" { + dataDir = filepath.Join(xdgDataHome, "lmcli") + } else { + userHomeDir, _ := os.UserHomeDir() + dataDir = filepath.Join(userHomeDir, ".local/share/lmcli") + } + + os.MkdirAll(dataDir, 0755) + return dataDir +} + +func Fatal(format string, args ...any) { + fmt.Fprintf(os.Stderr, format, args...) + os.Exit(1) +} + +func Warn(format string, args ...any) { + fmt.Fprintf(os.Stderr, format, args...) +} diff --git a/pkg/lmcli/model/conversation.go b/pkg/lmcli/model/conversation.go new file mode 100644 index 0000000..5494b90 --- /dev/null +++ b/pkg/lmcli/model/conversation.go @@ -0,0 +1,58 @@ +package model + +import ( + "database/sql" + "time" +) + +type MessageRole string + +const ( + MessageRoleSystem MessageRole = "system" + MessageRoleUser MessageRole = "user" + MessageRoleAssistant MessageRole = "assistant" + MessageRoleToolCall MessageRole = "tool_call" + MessageRoleToolResult MessageRole = "tool_result" +) + +type Message struct { + ID uint `gorm:"primaryKey"` + ConversationID uint `gorm:"foreignKey:ConversationID"` + Content string + Role MessageRole + CreatedAt time.Time + ToolCalls ToolCalls // a json array of tool calls (from the modl) + ToolResults ToolResults // a json array of tool results +} + +type Conversation struct { + ID uint `gorm:"primaryKey"` + ShortName sql.NullString + Title string +} + +type RequestParameters struct { + Model string + MaxTokens int + Temperature float32 + TopP float32 + + SystemPrompt string + ToolBag []Tool +} + +// FriendlyRole returns a human friendly signifier for the message's role. +func (m *MessageRole) FriendlyRole() string { + var friendlyRole string + switch *m { + case MessageRoleUser: + friendlyRole = "You" + case MessageRoleSystem: + friendlyRole = "System" + case MessageRoleAssistant: + friendlyRole = "Assistant" + default: + friendlyRole = string(*m) + } + return friendlyRole +} diff --git a/pkg/lmcli/model/tool.go b/pkg/lmcli/model/tool.go new file mode 100644 index 0000000..8b5ddea --- /dev/null +++ b/pkg/lmcli/model/tool.go @@ -0,0 +1,98 @@ +package model + +import ( + "database/sql/driver" + "encoding/json" + "fmt" +) + +type Tool struct { + Name string + Description string + Parameters []ToolParameter + Impl func(*Tool, map[string]interface{}) (string, error) +} + +type ToolParameter struct { + Name string `json:"name"` + Type string `json:"type"` // "string", "integer", "boolean" + Required bool `json:"required"` + Description string `json:"description"` + Enum []string `json:"enum,omitempty"` +} + +type ToolCall struct { + ID string `json:"id"` + Name string `json:"name"` + Parameters map[string]interface{} `json:"parameters"` +} + +type ToolCalls []ToolCall + +func (tc *ToolCalls) Scan(value any) (err error) { + s := value.(string) + if value == nil || s == "" { + *tc = nil + return + } + err = json.Unmarshal([]byte(s), tc) + return +} + +func (tc ToolCalls) Value() (driver.Value, error) { + if len(tc) == 0 { + return "", nil + } + jsonBytes, err := json.Marshal(tc) + if err != nil { + return "", fmt.Errorf("Could not marshal ToolCalls to JSON: %v\n", err) + } + return string(jsonBytes), nil +} + +type ToolResult struct { + ToolCallID string `json:"toolCallID"` + ToolName string `json:"toolName,omitempty"` + Result string `json:"result,omitempty"` +} + +type ToolResults []ToolResult + +func (tr *ToolResults) Scan(value any) (err error) { + s := value.(string) + if value == nil || s == "" { + *tr = nil + return + } + err = json.Unmarshal([]byte(s), tr) + return +} + +func (tr ToolResults) Value() (driver.Value, error) { + if len(tr) == 0 { + return "", nil + } + jsonBytes, err := json.Marshal([]ToolResult(tr)) + if err != nil { + return "", fmt.Errorf("Could not marshal ToolResults to JSON: %v\n", err) + } + return string(jsonBytes), nil +} + +type CallResult struct { + Message string `json:"message"` + Result any `json:"result,omitempty"` +} + +func (r CallResult) ToJson() (string, error) { + if r.Message == "" { + // When message not supplied, assume success + r.Message = "success" + } + + jsonBytes, err := json.Marshal(r) + if err != nil { + return "", fmt.Errorf("Could not marshal CallResult to JSON: %v\n", err) + } + return string(jsonBytes), nil +} diff --git a/pkg/lmcli/provider/anthropic/anthropic.go b/pkg/lmcli/provider/anthropic/anthropic.go new file mode 100644 index 0000000..9761bd0 --- /dev/null +++ b/pkg/lmcli/provider/anthropic/anthropic.go @@ -0,0 +1,322 @@ +package anthropic + +import ( + "bufio" + "bytes" + "encoding/json" + "encoding/xml" + "fmt" + "net/http" + "strings" + + "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" + "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" +) + +type AnthropicClient struct { + APIKey string +} + +type Message struct { + Role string `json:"role"` + OriginalContent string `json:"content"` +} + +type Request struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + System string `json:"system,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + //TopP float32 `json:"top_p,omitempty"` + //TopK float32 `json:"top_k,omitempty"` +} + +type OriginalContent struct { + Type string `json:"type"` + Text string `json:"text"` +} + +type Response struct { + Id string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + OriginalContent []OriginalContent `json:"content"` +} + +const FUNCTION_STOP_SEQUENCE = "" + +func buildRequest(params model.RequestParameters, messages []model.Message) Request { + requestBody := Request{ + Model: params.Model, + Messages: make([]Message, len(messages)), + System: params.SystemPrompt, + MaxTokens: params.MaxTokens, + Temperature: params.Temperature, + Stream: false, + + StopSequences: []string{ + FUNCTION_STOP_SEQUENCE, + "\n\nHuman:", + }, + } + + startIdx := 0 + if messages[0].Role == model.MessageRoleSystem { + requestBody.System = messages[0].Content + requestBody.Messages = requestBody.Messages[:len(messages)-1] + startIdx = 1 + } + + if len(params.ToolBag) > 0 { + if len(requestBody.System) > 0 { + // add a divider between existing system prompt and tools + requestBody.System += "\n\n---\n\n" + } + requestBody.System += buildToolsSystemPrompt(params.ToolBag) + } + + for i, msg := range messages[startIdx:] { + message := &requestBody.Messages[i] + + switch msg.Role { + case model.MessageRoleToolCall: + message.Role = "assistant" + message.OriginalContent = msg.Content + //message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls) + case model.MessageRoleToolResult: + xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults) + xmlString, err := xmlFuncResults.XMLString() + if err != nil { + panic("Could not serialize []ToolResult to XMLFunctionResults") + } + message.Role = "user" + message.OriginalContent = xmlString + default: + message.Role = string(msg.Role) + message.OriginalContent = msg.Content + } + } + return requestBody +} + +func sendRequest(c *AnthropicClient, r Request) (*http.Response, error) { + url := "https://api.anthropic.com/v1/messages" + + jsonBody, err := json.Marshal(r) + if err != nil { + return nil, fmt.Errorf("failed to marshal request body: %v", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody)) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %v", err) + } + + req.Header.Set("x-api-key", c.APIKey) + req.Header.Set("anthropic-version", "2023-06-01") + req.Header.Set("content-type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send HTTP request: %v", err) + } + + return resp, nil +} + +func (c *AnthropicClient) CreateChatCompletion( + params model.RequestParameters, + messages []model.Message, + replies *[]model.Message, +) (string, error) { + request := buildRequest(params, messages) + + resp, err := sendRequest(c, request) + if err != nil { + return "", err + } + defer resp.Body.Close() + + var response Response + err = json.NewDecoder(resp.Body).Decode(&response) + if err != nil { + return "", fmt.Errorf("failed to decode response: %v", err) + } + + sb := strings.Builder{} + for _, content := range response.OriginalContent { + var reply model.Message + switch content.Type { + case "text": + reply = model.Message{ + Role: model.MessageRoleAssistant, + Content: content.Text, + } + sb.WriteString(reply.Content) + default: + return "", fmt.Errorf("unsupported message type: %s", content.Type) + } + *replies = append(*replies, reply) + } + + return sb.String(), nil +} + +func (c *AnthropicClient) CreateChatCompletionStream( + params model.RequestParameters, + messages []model.Message, + replies *[]model.Message, + output chan<- string, +) (string, error) { + request := buildRequest(params, messages) + request.Stream = true + + resp, err := sendRequest(c, request) + if err != nil { + return "", err + } + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + sb := strings.Builder{} + + for scanner.Scan() { + line := scanner.Text() + line = strings.TrimSpace(line) + + if len(line) == 0 { + continue + } + + if line[0] == '{' { + var event map[string]interface{} + err := json.Unmarshal([]byte(line), &event) + if err != nil { + return "", fmt.Errorf("failed to unmarshal event data '%s': %v", line, err) + } + eventType, ok := event["type"].(string) + if !ok { + return "", fmt.Errorf("invalid event: %s", line) + } + switch eventType { + case "error": + return sb.String(), fmt.Errorf("an error occurred: %s", event["error"]) + default: + return sb.String(), fmt.Errorf("unknown event type: %s", eventType) + } + } else if strings.HasPrefix(line, "data:") { + data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + var event map[string]interface{} + err := json.Unmarshal([]byte(data), &event) + if err != nil { + return "", fmt.Errorf("failed to unmarshal event data: %v", err) + } + + eventType, ok := event["type"].(string) + if !ok { + return "", fmt.Errorf("invalid event type") + } + + switch eventType { + case "message_start": + // noop + case "ping": + // write an empty string to signal start of text + output <- "" + case "content_block_start": + // ignore? + case "content_block_delta": + delta, ok := event["delta"].(map[string]interface{}) + if !ok { + return "", fmt.Errorf("invalid content block delta") + } + text, ok := delta["text"].(string) + if !ok { + return "", fmt.Errorf("invalid text delta") + } + sb.WriteString(text) + output <- text + case "content_block_stop": + // ignore? + case "message_delta": + delta, ok := event["delta"].(map[string]interface{}) + if !ok { + return "", fmt.Errorf("invalid message delta") + } + stopReason, ok := delta["stop_reason"].(string) + if ok && stopReason == "stop_sequence" { + stopSequence, ok := delta["stop_sequence"].(string) + if ok && stopSequence == FUNCTION_STOP_SEQUENCE { + content := sb.String() + + start := strings.Index(content, "") + if start == -1 { + return content, fmt.Errorf("reached stop sequence but no opening tag found") + } + + funcCallXml := content[start:] + funcCallXml += FUNCTION_STOP_SEQUENCE + + sb.WriteString(FUNCTION_STOP_SEQUENCE) + output <- FUNCTION_STOP_SEQUENCE + + // Extract function calls + var functionCalls XMLFunctionCalls + err := xml.Unmarshal([]byte(sb.String()), &functionCalls) + if err != nil { + return "", fmt.Errorf("failed to unmarshal function_calls: %v", err) + } + + // Execute function calls + toolCall := model.Message{ + Role: model.MessageRoleToolCall, + Content: content, + ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls), + } + + toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag) + if err != nil { + return "", err + } + + toolReply := model.Message{ + Role: model.MessageRoleToolResult, + ToolResults: toolResults, + } + + if replies != nil { + *replies = append(append(*replies, toolCall), toolReply) + } + + // Recurse into CreateChatCompletionStream with the tool call replies + // added to the original messages + messages = append(append(messages, toolCall), toolReply) + return c.CreateChatCompletionStream(params, messages, replies, output) + } + } + case "message_stop": + // return the completed message + reply := model.Message{ + Role: model.MessageRoleAssistant, + Content: sb.String(), + } + *replies = append(*replies, reply) + return sb.String(), nil + case "error": + return sb.String(), fmt.Errorf("an error occurred: %s", event["error"]) + default: + fmt.Printf("\nUnrecognized event: %s\n", data) + } + } + } + + if err := scanner.Err(); err != nil { + return "", fmt.Errorf("failed to read response body: %v", err) + } + + return "", fmt.Errorf("unexpected end of stream") +} diff --git a/pkg/lmcli/provider/anthropic/tools.go b/pkg/lmcli/provider/anthropic/tools.go new file mode 100644 index 0000000..de89d4a --- /dev/null +++ b/pkg/lmcli/provider/anthropic/tools.go @@ -0,0 +1,182 @@ +package anthropic + +import ( + "bytes" + "strings" + "text/template" + + "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" +) + +const TOOL_PREAMBLE = `In this environment you have access to a set of tools which may assist you in fulfilling user requests. + +You may call them like this: + + +$TOOL_NAME + +<$PARAMETER_NAME>$PARAMETER_VALUE +... + + + + +Here are the tools available:` + +type XMLTools struct { + XMLName struct{} `xml:"tools"` + ToolDescriptions []XMLToolDescription `xml:"tool_description"` +} + +type XMLToolDescription struct { + ToolName string `xml:"tool_name"` + Description string `xml:"description"` + Parameters []XMLToolParameter `xml:"parameters>parameter"` +} + +type XMLToolParameter struct { + Name string `xml:"name"` + Type string `xml:"type"` + Description string `xml:"description"` +} + +type XMLFunctionCalls struct { + XMLName struct{} `xml:"function_calls"` + Invoke []XMLFunctionInvoke `xml:"invoke"` +} + +type XMLFunctionInvoke struct { + ToolName string `xml:"tool_name"` + Parameters XMLFunctionInvokeParameters `xml:"parameters"` +} + +type XMLFunctionInvokeParameters struct { + String string `xml:",innerxml"` +} + +type XMLFunctionResults struct { + XMLName struct{} `xml:"function_results"` + Result []XMLFunctionResult `xml:"result"` +} + +type XMLFunctionResult struct { + ToolName string `xml:"tool_name"` + Stdout string `xml:"stdout"` +} + +// accepts raw XML from XMLFunctionInvokeParameters.String, returns map of +// parameters name to value +func parseFunctionParametersXML(params string) map[string]interface{} { + lines := strings.Split(params, "\n") + ret := make(map[string]interface{}, len(lines)) + for _, line := range lines { + i := strings.Index(line, ">") + if i == -1 { + continue + } + j := strings.Index(line, " to get parameter name, + // then chop after > to first +{{range .ToolDescriptions}} +{{.ToolName}} + +{{.Description}} + + +{{range .Parameters}} +{{.Name}} +{{.Type}} +{{.Description}} + +{{end}} + +{{end}}`) + if err != nil { + return "", err + } + + var buf bytes.Buffer + if err := tmpl.Execute(&buf, x); err != nil { + return "", err + } + + return buf.String(), nil +} + +func (x XMLFunctionResults) XMLString() (string, error) { + tmpl, err := template.New("function_results").Parse(` +{{range .Result}} +{{.ToolName}} +{{.Stdout}} + +{{end}}`) + if err != nil { + return "", err + } + + var buf bytes.Buffer + if err := tmpl.Execute(&buf, x); err != nil { + return "", err + } + + return buf.String(), nil +} diff --git a/pkg/lmcli/provider/openai/openai.go b/pkg/lmcli/provider/openai/openai.go new file mode 100644 index 0000000..b21791d --- /dev/null +++ b/pkg/lmcli/provider/openai/openai.go @@ -0,0 +1,270 @@ +package openai + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "strings" + + "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" + "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" + openai "github.com/sashabaranov/go-openai" +) + +type OpenAIClient struct { + APIKey string +} + +type OpenAIToolParameters struct { + Type string `json:"type"` + Properties map[string]OpenAIToolParameter `json:"properties,omitempty"` + Required []string `json:"required,omitempty"` +} + +type OpenAIToolParameter struct { + Type string `json:"type"` + Description string `json:"description"` + Enum []string `json:"enum,omitempty"` +} + +func convertTools(tools []model.Tool) []openai.Tool { + openaiTools := make([]openai.Tool, len(tools)) + for i, tool := range tools { + openaiTools[i].Type = "function" + + params := make(map[string]OpenAIToolParameter) + var required []string + + for _, param := range tool.Parameters { + params[param.Name] = OpenAIToolParameter{ + Type: param.Type, + Description: param.Description, + Enum: param.Enum, + } + if param.Required { + required = append(required, param.Name) + } + } + + openaiTools[i].Function = openai.FunctionDefinition{ + Name: tool.Name, + Description: tool.Description, + Parameters: OpenAIToolParameters{ + Type: "object", + Properties: params, + Required: required, + }, + } + } + return openaiTools +} + +func convertToolCallToOpenAI(toolCalls []model.ToolCall) []openai.ToolCall { + converted := make([]openai.ToolCall, len(toolCalls)) + for i, call := range toolCalls { + converted[i].Type = "function" + converted[i].ID = call.ID + converted[i].Function.Name = call.Name + + json, _ := json.Marshal(call.Parameters) + converted[i].Function.Arguments = string(json) + } + return converted +} + +func convertToolCallToAPI(toolCalls []openai.ToolCall) []model.ToolCall { + converted := make([]model.ToolCall, len(toolCalls)) + for i, call := range toolCalls { + converted[i].ID = call.ID + converted[i].Name = call.Function.Name + json.Unmarshal([]byte(call.Function.Arguments), &converted[i].Parameters) + } + return converted +} + +func createChatCompletionRequest( + c *OpenAIClient, + params model.RequestParameters, + messages []model.Message, +) openai.ChatCompletionRequest { + requestMessages := make([]openai.ChatCompletionMessage, 0, len(messages)) + + for _, m := range messages { + switch m.Role { + case "tool_call": + message := openai.ChatCompletionMessage{} + message.Role = "assistant" + message.Content = m.Content + message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls) + requestMessages = append(requestMessages, message) + case "tool_result": + // expand tool_result messages' results into multiple openAI messages + for _, result := range m.ToolResults { + message := openai.ChatCompletionMessage{} + message.Role = "tool" + message.Content = result.Result + message.ToolCallID = result.ToolCallID + requestMessages = append(requestMessages, message) + } + default: + message := openai.ChatCompletionMessage{} + message.Role = string(m.Role) + message.Content = m.Content + requestMessages = append(requestMessages, message) + } + } + + request := openai.ChatCompletionRequest{ + Model: params.Model, + MaxTokens: params.MaxTokens, + Temperature: params.Temperature, + Messages: requestMessages, + N: 1, // limit responses to 1 "choice". we use choices[0] to reference it + } + + if len(params.ToolBag) > 0 { + request.Tools = convertTools(params.ToolBag) + request.ToolChoice = "auto" + } + + return request +} + +func handleToolCalls( + params model.RequestParameters, + content string, + toolCalls []openai.ToolCall, +) ([]model.Message, error) { + toolCall := model.Message{ + Role: model.MessageRoleToolCall, + Content: content, + ToolCalls: convertToolCallToAPI(toolCalls), + } + + toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag) + if err != nil { + return nil, err + } + + toolResult := model.Message{ + Role: model.MessageRoleToolResult, + ToolResults: toolResults, + } + + return []model.Message{toolCall, toolResult}, nil +} + +func (c *OpenAIClient) CreateChatCompletion( + params model.RequestParameters, + messages []model.Message, + replies *[]model.Message, +) (string, error) { + client := openai.NewClient(c.APIKey) + req := createChatCompletionRequest(c, params, messages) + resp, err := client.CreateChatCompletion(context.Background(), req) + if err != nil { + return "", err + } + + choice := resp.Choices[0] + + toolCalls := choice.Message.ToolCalls + if len(toolCalls) > 0 { + results, err := handleToolCalls(params, choice.Message.Content, toolCalls) + if err != nil { + return "", err + } + if results != nil { + *replies = append(*replies, results...) + } + + // Recurse into CreateChatCompletion with the tool call replies + messages = append(messages, results...) + return c.CreateChatCompletion(params, messages, replies) + } + + if replies != nil { + *replies = append(*replies, model.Message{ + Role: model.MessageRoleAssistant, + Content: choice.Message.Content, + }) + } + + // Return the user-facing message. + return choice.Message.Content, nil +} + +func (c *OpenAIClient) CreateChatCompletionStream( + params model.RequestParameters, + messages []model.Message, + replies *[]model.Message, + output chan<- string, +) (string, error) { + client := openai.NewClient(c.APIKey) + req := createChatCompletionRequest(c, params, messages) + + stream, err := client.CreateChatCompletionStream(context.Background(), req) + if err != nil { + return "", err + } + defer stream.Close() + + content := strings.Builder{} + toolCalls := []openai.ToolCall{} + + // Iterate stream segments + for { + response, e := stream.Recv() + if errors.Is(e, io.EOF) { + break + } + + if e != nil { + err = e + break + } + + delta := response.Choices[0].Delta + if len(delta.ToolCalls) > 0 { + // Construct streamed tool_call arguments + for _, tc := range delta.ToolCalls { + if tc.Index == nil { + return "", fmt.Errorf("Unexpected nil index for streamed tool call.") + } + if len(toolCalls) <= *tc.Index { + toolCalls = append(toolCalls, tc) + } else { + toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments + } + } + } else { + output <- delta.Content + content.WriteString(delta.Content) + } + } + + if len(toolCalls) > 0 { + results, err := handleToolCalls(params, content.String(), toolCalls) + if err != nil { + return content.String(), err + } + if results != nil { + *replies = append(*replies, results...) + } + + // Recurse into CreateChatCompletionStream with the tool call replies + messages = append(messages, results...) + return c.CreateChatCompletionStream(params, messages, replies, output) + } + + if replies != nil { + *replies = append(*replies, model.Message{ + Role: model.MessageRoleAssistant, + Content: content.String(), + }) + } + + return content.String(), err +} diff --git a/pkg/lmcli/provider/provider.go b/pkg/lmcli/provider/provider.go new file mode 100644 index 0000000..f0e4ace --- /dev/null +++ b/pkg/lmcli/provider/provider.go @@ -0,0 +1,23 @@ +package provider + +import "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" + +type ChatCompletionClient interface { + // CreateChatCompletion requests a response to the provided messages. + // Replies are appended to the given replies struct, and the + // complete user-facing response is returned as a string. + CreateChatCompletion( + params model.RequestParameters, + messages []model.Message, + replies *[]model.Message, + ) (string, error) + + // Like CreateChageCompletion, except the response is streamed via + // the output channel as it's received. + CreateChatCompletionStream( + params model.RequestParameters, + messages []model.Message, + replies *[]model.Message, + output chan<- string, + ) (string, error) +} diff --git a/pkg/lmcli/store.go b/pkg/lmcli/store.go new file mode 100644 index 0000000..f0b3b98 --- /dev/null +++ b/pkg/lmcli/store.go @@ -0,0 +1,121 @@ +package lmcli + +import ( + "database/sql" + "errors" + "fmt" + "strings" + + "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" + sqids "github.com/sqids/sqids-go" + "gorm.io/gorm" +) + +type ConversationStore interface { + Conversations() ([]model.Conversation, error) + + ConversationByShortName(shortName string) (*model.Conversation, error) + ConversationShortNameCompletions(search string) []string + + SaveConversation(conversation *model.Conversation) error + DeleteConversation(conversation *model.Conversation) error + + Messages(conversation *model.Conversation) ([]model.Message, error) + LastMessage(conversation *model.Conversation) (*model.Message, error) + + SaveMessage(message *model.Message) error + DeleteMessage(message *model.Message) error + UpdateMessage(message *model.Message) error +} + +type SQLStore struct { + db *gorm.DB + sqids *sqids.Sqids +} + +func NewSQLStore(db *gorm.DB) (*SQLStore, error) { + models := []any{ + &model.Conversation{}, + &model.Message{}, + } + + for _, x := range models { + err := db.AutoMigrate(x) + if err != nil { + return nil, fmt.Errorf("Could not perform database migrations: %v", err) + } + } + + _sqids, _ := sqids.New(sqids.Options{MinLength: 4}) + return &SQLStore{db, _sqids}, nil +} + +func (s *SQLStore) SaveConversation(conversation *model.Conversation) error { + err := s.db.Save(&conversation).Error + if err != nil { + return err + } + + if !conversation.ShortName.Valid { + shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)}) + conversation.ShortName = sql.NullString{String: shortName, Valid: true} + err = s.db.Save(&conversation).Error + } + + return err +} + +func (s *SQLStore) DeleteConversation(conversation *model.Conversation) error { + s.db.Where("conversation_id = ?", conversation.ID).Delete(&model.Message{}) + return s.db.Delete(&conversation).Error +} + +func (s *SQLStore) SaveMessage(message *model.Message) error { + return s.db.Create(message).Error +} + +func (s *SQLStore) DeleteMessage(message *model.Message) error { + return s.db.Delete(&message).Error +} + +func (s *SQLStore) UpdateMessage(message *model.Message) error { + return s.db.Updates(&message).Error +} + +func (s *SQLStore) Conversations() ([]model.Conversation, error) { + var conversations []model.Conversation + err := s.db.Find(&conversations).Error + return conversations, err +} + +func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string { + var completions []string + conversations, _ := s.Conversations() // ignore error for completions + for _, conversation := range conversations { + if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) { + completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title)) + } + } + return completions +} + +func (s *SQLStore) ConversationByShortName(shortName string) (*model.Conversation, error) { + if shortName == "" { + return nil, errors.New("shortName is empty") + } + var conversation model.Conversation + err := s.db.Where("short_name = ?", shortName).Find(&conversation).Error + return &conversation, err +} + +func (s *SQLStore) Messages(conversation *model.Conversation) ([]model.Message, error) { + var messages []model.Message + err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error + return messages, err +} + +func (s *SQLStore) LastMessage(conversation *model.Conversation) (*model.Message, error) { + var message model.Message + err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error + return &message, err +} diff --git a/pkg/lmcli/tools/file_insert_lines.go b/pkg/lmcli/tools/file_insert_lines.go new file mode 100644 index 0000000..513f9a5 --- /dev/null +++ b/pkg/lmcli/tools/file_insert_lines.go @@ -0,0 +1,114 @@ +package tools + +import ( + "fmt" + "os" + "strings" + + toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util" + "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" +) + +const FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path. + +Make sure your inserts match the flow and indentation of surrounding content.` + +var FileInsertLinesTool = model.Tool{ + Name: "file_insert_lines", + Description: FILE_INSERT_LINES_DESCRIPTION, + Parameters: []model.ToolParameter{ + { + Name: "path", + Type: "string", + Description: "Path of the file to be modified, relative to the current working directory.", + Required: true, + }, + { + Name: "position", + Type: "integer", + Description: `Which line to insert content *before*.`, + Required: true, + }, + { + Name: "content", + Type: "string", + Description: `The content to insert.`, + Required: true, + }, + }, + Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) { + tmp, ok := args["path"] + if !ok { + return "", fmt.Errorf("path parameter to write_file was not included.") + } + path, ok := tmp.(string) + if !ok { + return "", fmt.Errorf("Invalid path in function arguments: %v", tmp) + } + var position int + tmp, ok = args["position"] + if ok { + tmp, ok := tmp.(float64) + if !ok { + return "", fmt.Errorf("Invalid position in function arguments: %v", tmp) + } + position = int(tmp) + } + var content string + tmp, ok = args["content"] + if ok { + content, ok = tmp.(string) + if !ok { + return "", fmt.Errorf("Invalid content in function arguments: %v", tmp) + } + } + + result := fileInsertLines(path, position, content) + ret, err := result.ToJson() + if err != nil { + return "", fmt.Errorf("Could not serialize result: %v", err) + } + return ret, nil + }, +} + +func fileInsertLines(path string, position int, content string) model.CallResult { + ok, reason := toolutil.IsPathWithinCWD(path) + if !ok { + return model.CallResult{Message: reason} + } + + // Read the existing file's content + data, err := os.ReadFile(path) + if err != nil { + if !os.IsNotExist(err) { + return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())} + } + _, err = os.Create(path) + if err != nil { + return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())} + } + data = []byte{} + } + + if position < 1 { + return model.CallResult{Message: "start_line cannot be less than 1"} + } + + lines := strings.Split(string(data), "\n") + contentLines := strings.Split(strings.Trim(content, "\n"), "\n") + + before := lines[:position-1] + after := lines[position-1:] + lines = append(before, append(contentLines, after...)...) + + newContent := strings.Join(lines, "\n") + + // Join the lines and write back to the file + err = os.WriteFile(path, []byte(newContent), 0644) + if err != nil { + return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())} + } + + return model.CallResult{Result: newContent} +} diff --git a/pkg/lmcli/tools/file_replace_lines.go b/pkg/lmcli/tools/file_replace_lines.go new file mode 100644 index 0000000..cdb1def --- /dev/null +++ b/pkg/lmcli/tools/file_replace_lines.go @@ -0,0 +1,133 @@ +package tools + +import ( + "fmt" + "os" + "strings" + + "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" + toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util" +) + +const FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path. + +Useful for re-writing snippets/blocks of code or entire functions. + +Plan your edits carefully and ensure any new content matches the flow and indentation of surrounding text.` + +var FileReplaceLinesTool = model.Tool{ + Name: "file_replace_lines", + Description: FILE_REPLACE_LINES_DESCRIPTION, + Parameters: []model.ToolParameter{ + { + Name: "path", + Type: "string", + Description: "Path of the file to be modified, relative to the current working directory.", + Required: true, + }, + { + Name: "start_line", + Type: "integer", + Description: `Line number which specifies the start of the replacement range (inclusive).`, + Required: true, + }, + { + Name: "end_line", + Type: "integer", + Description: `Line number which specifies the end of the replacement range (inclusive). If unset, range extends to end of file.`, + }, + { + Name: "content", + Type: "string", + Description: `Content to replace specified range. Omit to remove the specified range.`, + }, + }, + Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) { + tmp, ok := args["path"] + if !ok { + return "", fmt.Errorf("path parameter to write_file was not included.") + } + path, ok := tmp.(string) + if !ok { + return "", fmt.Errorf("Invalid path in function arguments: %v", tmp) + } + var start_line int + tmp, ok = args["start_line"] + if ok { + tmp, ok := tmp.(float64) + if !ok { + return "", fmt.Errorf("Invalid start_line in function arguments: %v", tmp) + } + start_line = int(tmp) + } + var end_line int + tmp, ok = args["end_line"] + if ok { + tmp, ok := tmp.(float64) + if !ok { + return "", fmt.Errorf("Invalid end_line in function arguments: %v", tmp) + } + end_line = int(tmp) + } + var content string + tmp, ok = args["content"] + if ok { + content, ok = tmp.(string) + if !ok { + return "", fmt.Errorf("Invalid content in function arguments: %v", tmp) + } + } + + result := fileReplaceLines(path, start_line, end_line, content) + ret, err := result.ToJson() + if err != nil { + return "", fmt.Errorf("Could not serialize result: %v", err) + } + return ret, nil + }, +} + +func fileReplaceLines(path string, startLine int, endLine int, content string) model.CallResult { + ok, reason := toolutil.IsPathWithinCWD(path) + if !ok { + return model.CallResult{Message: reason} + } + + // Read the existing file's content + data, err := os.ReadFile(path) + if err != nil { + if !os.IsNotExist(err) { + return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())} + } + _, err = os.Create(path) + if err != nil { + return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())} + } + data = []byte{} + } + + if startLine < 1 { + return model.CallResult{Message: "start_line cannot be less than 1"} + } + + lines := strings.Split(string(data), "\n") + contentLines := strings.Split(strings.Trim(content, "\n"), "\n") + + if endLine == 0 || endLine > len(lines) { + endLine = len(lines) + } + + before := lines[:startLine-1] + after := lines[endLine:] + + lines = append(before, append(contentLines, after...)...) + newContent := strings.Join(lines, "\n") + + // Join the lines and write back to the file + err = os.WriteFile(path, []byte(newContent), 0644) + if err != nil { + return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())} + } + + return model.CallResult{Result: newContent} +} diff --git a/pkg/lmcli/tools/read_dir.go b/pkg/lmcli/tools/read_dir.go new file mode 100644 index 0000000..46534e4 --- /dev/null +++ b/pkg/lmcli/tools/read_dir.go @@ -0,0 +1,100 @@ +package tools + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" + toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util" +) + +const READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory). + +Example result: +{ + "message": "success", + "result": [ + {"name": "a_file.txt", "type": "file", "size": 123}, + {"name": "a_directory/", "type": "dir", "size": 11}, + ... + ] +} + +For files, size represents the size of the file, in bytes. +For directories, size represents the number of entries in that directory.` + +var ReadDirTool = model.Tool{ + Name: "read_dir", + Description: READ_DIR_DESCRIPTION, + Parameters: []model.ToolParameter{ + { + Name: "relative_dir", + Type: "string", + Description: "If set, read the contents of a directory relative to the current one.", + }, + }, + Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) { + var relativeDir string + tmp, ok := args["relative_dir"] + if ok { + relativeDir, ok = tmp.(string) + if !ok { + return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp) + } + } + result := readDir(relativeDir) + ret, err := result.ToJson() + if err != nil { + return "", fmt.Errorf("Could not serialize result: %v", err) + } + return ret, nil + }, +} + +func readDir(path string) model.CallResult { + if path == "" { + path = "." + } + ok, reason := toolutil.IsPathWithinCWD(path) + if !ok { + return model.CallResult{Message: reason} + } + + files, err := os.ReadDir(path) + if err != nil { + return model.CallResult{ + Message: err.Error(), + } + } + + var dirContents []map[string]interface{} + for _, f := range files { + info, _ := f.Info() + + name := f.Name() + if strings.HasPrefix(name, ".") { + // skip hidden files + continue + } + + entryType := "file" + size := info.Size() + + if info.IsDir() { + name += "/" + entryType = "dir" + subdirfiles, _ := os.ReadDir(filepath.Join(".", path, info.Name())) + size = int64(len(subdirfiles)) + } + + dirContents = append(dirContents, map[string]interface{}{ + "name": name, + "type": entryType, + "size": size, + }) + } + + return model.CallResult{Result: dirContents} +} diff --git a/pkg/lmcli/tools/read_file.go b/pkg/lmcli/tools/read_file.go new file mode 100644 index 0000000..2b59500 --- /dev/null +++ b/pkg/lmcli/tools/read_file.go @@ -0,0 +1,71 @@ +package tools + +import ( + "fmt" + "os" + "strings" + + "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" + toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util" +) + +const READ_FILE_DESCRIPTION = `Read the contents of a text file relative to the current working directory. + +Each line of the returned content is prefixed with its line number and a tab (\t). + +Example result: +{ + "message": "success", + "result": "1\tthe contents\n2\tof the file\n" +}` + +var ReadFileTool = model.Tool{ + Name: "read_file", + Description: READ_FILE_DESCRIPTION, + Parameters: []model.ToolParameter{ + { + Name: "path", + Type: "string", + Description: "Path to a file within the current working directory to read.", + Required: true, + }, + }, + + Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) { + tmp, ok := args["path"] + if !ok { + return "", fmt.Errorf("Path parameter to read_file was not included.") + } + path, ok := tmp.(string) + if !ok { + return "", fmt.Errorf("Invalid path in function arguments: %v", tmp) + } + result := readFile(path) + ret, err := result.ToJson() + if err != nil { + return "", fmt.Errorf("Could not serialize result: %v", err) + } + return ret, nil + }, +} + +func readFile(path string) model.CallResult { + ok, reason := toolutil.IsPathWithinCWD(path) + if !ok { + return model.CallResult{Message: reason} + } + data, err := os.ReadFile(path) + if err != nil { + return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())} + } + + lines := strings.Split(string(data), "\n") + content := strings.Builder{} + for i, line := range lines { + content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line)) + } + + return model.CallResult{ + Result: content.String(), + } +} diff --git a/pkg/lmcli/tools/tools.go b/pkg/lmcli/tools/tools.go new file mode 100644 index 0000000..d940a89 --- /dev/null +++ b/pkg/lmcli/tools/tools.go @@ -0,0 +1,51 @@ +package tools + +import ( + "fmt" + "os" + + "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" +) + +var AvailableTools map[string]model.Tool = map[string]model.Tool{ + "read_dir": ReadDirTool, + "read_file": ReadFileTool, + "write_file": WriteFileTool, + "file_insert_lines": FileInsertLinesTool, + "file_replace_lines": FileReplaceLinesTool, +} + +func ExecuteToolCalls(toolCalls []model.ToolCall, toolBag []model.Tool) ([]model.ToolResult, error) { + var toolResults []model.ToolResult + for _, toolCall := range toolCalls { + var tool *model.Tool + for _, available := range toolBag { + if available.Name == toolCall.Name { + tool = &available + break + } + } + if tool == nil { + return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Name) + } + + // TODO: ability to silence this + fmt.Fprintf(os.Stderr, "\nINFO: Executing tool '%s' with args %s\n", toolCall.Name, toolCall.Parameters) + + // Execute the tool + result, err := tool.Impl(tool, toolCall.Parameters) + if err != nil { + // This can happen if the model missed or supplied invalid tool args + return nil, fmt.Errorf("Tool '%s' error: %v\n", toolCall.Name, err) + } + + toolResult := model.ToolResult{ + ToolCallID: toolCall.ID, + ToolName: toolCall.Name, + Result: result, + } + + toolResults = append(toolResults, toolResult) + } + return toolResults, nil +} diff --git a/pkg/lmcli/tools/util/util.go b/pkg/lmcli/tools/util/util.go new file mode 100644 index 0000000..eb8a8b7 --- /dev/null +++ b/pkg/lmcli/tools/util/util.go @@ -0,0 +1,67 @@ +package util + +import ( + "fmt" + "os" + "path/filepath" + "strings" +) + +// isPathContained attempts to verify whether `path` is the same as or +// contained within `directory`. It is overly cautious, returning false even if +// `path` IS contained within `directory`, but the two paths use different +// casing, and we happen to be on a case-insensitive filesystem. +// This is ultimately to attempt to stop an LLM from going outside of where I +// tell it to. Additional layers of security should be considered.. run in a +// VM/container. +func IsPathContained(directory string, path string) (bool, error) { + // Clean and resolve symlinks for both paths + path, err := filepath.Abs(path) + if err != nil { + return false, err + } + + // check if path exists + _, err = os.Stat(path) + if err != nil { + if !os.IsNotExist(err) { + return false, fmt.Errorf("Could not stat path: %v", err) + } + } else { + path, err = filepath.EvalSymlinks(path) + if err != nil { + return false, err + } + } + + directory, err = filepath.Abs(directory) + if err != nil { + return false, err + } + directory, err = filepath.EvalSymlinks(directory) + if err != nil { + return false, err + } + + // Case insensitive checks + if !strings.EqualFold(path, directory) && + !strings.HasPrefix(strings.ToLower(path), strings.ToLower(directory)+string(os.PathSeparator)) { + return false, nil + } + + return true, nil +} + +func IsPathWithinCWD(path string) (bool, string) { + cwd, err := os.Getwd() + if err != nil { + return false, "Failed to determine current working directory" + } + if ok, err := IsPathContained(cwd, path); !ok { + if err != nil { + return false, fmt.Sprintf("Could not determine whether path '%s' is within the current working directory: %s", path, err.Error()) + } + return false, fmt.Sprintf("Path '%s' is not within the current working directory", path) + } + return true, "" +} diff --git a/pkg/lmcli/tools/write_file.go b/pkg/lmcli/tools/write_file.go new file mode 100644 index 0000000..7263db5 --- /dev/null +++ b/pkg/lmcli/tools/write_file.go @@ -0,0 +1,71 @@ +package tools + +import ( + "fmt" + "os" + + "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" + toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util" +) + +const WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory. + +Example result: +{ + "message": "success" +}` + +var WriteFileTool = model.Tool{ + Name: "write_file", + Description: WRITE_FILE_DESCRIPTION, + Parameters: []model.ToolParameter{ + { + Name: "path", + Type: "string", + Description: "Path to a file within the current working directory to write to.", + Required: true, + }, + { + Name: "content", + Type: "string", + Description: "The content to write to the file. Overwrites any existing content!", + Required: true, + }, + }, + Impl: func(t *model.Tool, args map[string]interface{}) (string, error) { + tmp, ok := args["path"] + if !ok { + return "", fmt.Errorf("Path parameter to write_file was not included.") + } + path, ok := tmp.(string) + if !ok { + return "", fmt.Errorf("Invalid path in function arguments: %v", tmp) + } + tmp, ok = args["content"] + if !ok { + return "", fmt.Errorf("Content parameter to write_file was not included.") + } + content, ok := tmp.(string) + if !ok { + return "", fmt.Errorf("Invalid content in function arguments: %v", tmp) + } + result := writeFile(path, content) + ret, err := result.ToJson() + if err != nil { + return "", fmt.Errorf("Could not serialize result: %v", err) + } + return ret, nil + }, +} + +func writeFile(path string, content string) model.CallResult { + ok, reason := toolutil.IsPathWithinCWD(path) + if !ok { + return model.CallResult{Message: reason} + } + err := os.WriteFile(path, []byte(content), 0644) + if err != nil { + return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())} + } + return model.CallResult{} +} diff --git a/pkg/util/tty/highlight.go b/pkg/util/tty/highlight.go new file mode 100644 index 0000000..2548b0d --- /dev/null +++ b/pkg/util/tty/highlight.go @@ -0,0 +1,60 @@ +package tty + +import ( + "io" + "strings" + + "github.com/alecthomas/chroma/v2" + "github.com/alecthomas/chroma/v2/formatters" + "github.com/alecthomas/chroma/v2/lexers" + "github.com/alecthomas/chroma/v2/styles" +) + +type ChromaHighlighter struct { + lexer chroma.Lexer + formatter chroma.Formatter + style *chroma.Style +} + +func NewChromaHighlighter(lang, format, style string) *ChromaHighlighter { + l := lexers.Get(lang) + if l == nil { + l = lexers.Fallback + } + l = chroma.Coalesce(l) + + f := formatters.Get(format) + if f == nil { + f = formatters.Fallback + } + + s := styles.Get(style) + if s == nil { + s = styles.Fallback + } + + return &ChromaHighlighter{ + lexer: l, + formatter: f, + style: s, + } +} + +func (s *ChromaHighlighter) Highlight(w io.Writer, text string) error { + it, err := s.lexer.Tokenise(nil, text) + if err != nil { + return err + } + return s.formatter.Format(w, s.style, it) +} + +func (s *ChromaHighlighter) HighlightS(text string) (string, error) { + it, err := s.lexer.Tokenise(nil, text) + if err != nil { + return "", err + } + sb := strings.Builder{} + sb.Grow(len(text) * 2) + s.formatter.Format(&sb, s.style, it) + return sb.String(), nil +} diff --git a/pkg/cli/util.go b/pkg/util/util.go similarity index 89% rename from pkg/cli/util.go rename to pkg/util/util.go index 1bd9c95..c11a055 100644 --- a/pkg/cli/util.go +++ b/pkg/util/util.go @@ -1,4 +1,4 @@ -package cli +package util import ( "fmt" @@ -56,7 +56,7 @@ func InputFromEditor(placeholder string, pattern string, content string) (string // humanTimeElapsedSince returns a human-friendly "in the past" representation // of the given duration. -func humanTimeElapsedSince(d time.Duration) string { +func HumanTimeElapsedSince(d time.Duration) string { seconds := d.Seconds() minutes := seconds / 60 hours := minutes / 60 @@ -151,6 +151,14 @@ func SetStructDefaults(data interface{}) bool { intValue, _ := strconv.ParseInt(defaultTag, 10, 64) field.Set(reflect.New(e)) field.Elem().SetInt(intValue) + case reflect.Float32: + floatValue, _ := strconv.ParseFloat(defaultTag, 32) + field.Set(reflect.New(e)) + field.Elem().SetFloat(floatValue) + case reflect.Float64: + floatValue, _ := strconv.ParseFloat(defaultTag, 64) + field.Set(reflect.New(e)) + field.Elem().SetFloat(floatValue) case reflect.Bool: boolValue := defaultTag == "true" field.Set(reflect.ValueOf(&boolValue)) @@ -160,10 +168,8 @@ func SetStructDefaults(data interface{}) bool { return changed } -// FileContents returns the string contents of the given file. -// TODO: we should support retrieving the content (or an approximation of) -// non-text documents, e.g. PDFs. -func FileContents(file string) (string, error) { +// ReadFileContents returns the string contents of the given file. +func ReadFileContents(file string) (string, error) { path := filepath.Clean(file) content, err := os.ReadFile(path) if err != nil {