Compare commits
No commits in common. "0a27b9a8d34530be58e9250e5b6e15b5ca20354a" and "fa966d30db5fbf6f5dbefd80434b8defe73de919" have entirely different histories.
0a27b9a8d3 ... fa966d30db

go.mod (10 lines changed)
@@ -4,8 +4,8 @@ go 1.21
require (
    github.com/alecthomas/chroma/v2 v2.11.1
    github.com/charmbracelet/lipgloss v0.10.0
    github.com/go-yaml/yaml v2.1.0+incompatible
    github.com/gookit/color v1.5.4
    github.com/sashabaranov/go-openai v1.17.7
    github.com/spf13/cobra v1.8.0
    github.com/sqids/sqids-go v0.4.1
@@ -14,20 +14,14 @@ require (
)

require (
    github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
    github.com/dlclark/regexp2 v1.10.0 // indirect
    github.com/inconshreveable/mousetrap v1.1.0 // indirect
    github.com/jinzhu/inflection v1.0.0 // indirect
    github.com/jinzhu/now v1.1.5 // indirect
    github.com/kr/pretty v0.3.1 // indirect
    github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
    github.com/mattn/go-isatty v0.0.18 // indirect
    github.com/mattn/go-runewidth v0.0.15 // indirect
    github.com/mattn/go-sqlite3 v1.14.18 // indirect
    github.com/muesli/reflow v0.3.0 // indirect
    github.com/muesli/termenv v0.15.2 // indirect
    github.com/rivo/uniseg v0.4.7 // indirect
    github.com/spf13/pflag v1.0.5 // indirect
    github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect
    golang.org/x/sys v0.14.0 // indirect
    gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
    gopkg.in/yaml.v2 v2.2.2 // indirect
go.sum (31 lines changed)
@ -4,16 +4,16 @@ github.com/alecthomas/chroma/v2 v2.11.1 h1:m9uUtgcdAwgfFNxuqj7AIG75jD2YmL61BBIJW
|
||||
github.com/alecthomas/chroma/v2 v2.11.1/go.mod h1:4TQu7gdfuPjSh76j78ietmqh9LiurGF0EpseFXdKMBw=
|
||||
github.com/alecthomas/repr v0.2.0 h1:HAzS41CIzNW5syS8Mf9UwXhNH1J9aix/BvDRf1Ml2Yk=
|
||||
github.com/alecthomas/repr v0.2.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/charmbracelet/lipgloss v0.10.0 h1:KWeXFSexGcfahHX+54URiZGkBFazf70JNMtwg/AFW3s=
|
||||
github.com/charmbracelet/lipgloss v0.10.0/go.mod h1:Wig9DSfvANsxqkRsqj6x87irdy123SR4dOXlKa91ciE=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
||||
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/go-yaml/yaml v2.1.0+incompatible h1:RYi2hDdss1u4YE7GwixGzWwVo47T8UQwnTLB6vQiq+o=
|
||||
github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0=
|
||||
github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0=
|
||||
github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w=
|
||||
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
@ -26,24 +26,11 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98=
|
||||
github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
|
||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI=
|
||||
github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
|
||||
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
|
||||
github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo=
|
||||
github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
@ -55,7 +42,10 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw=
|
||||
github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 h1:QldyIu/L63oPpyvQmHgvgickp1Yw510KJOqX7H24mg8=
|
||||
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs=
|
||||
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
|
||||
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
@ -63,6 +53,7 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogR
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/driver/sqlite v1.5.4 h1:IqXwXi8M/ZlPzH/947tn5uik3aYQslP9BVveoax0nV0=
|
||||
gorm.io/driver/sqlite v1.5.4/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4=
|
||||
|
main.go (17 lines changed)
@@ -1,18 +1,15 @@
 package main
 
 import (
-    "git.mlow.ca/mlow/lmcli/pkg/lmcli"
-    "git.mlow.ca/mlow/lmcli/pkg/cmd"
+    "fmt"
+    "os"
+
+    "git.mlow.ca/mlow/lmcli/pkg/cli"
 )
 
 func main() {
-    ctx, err := lmcli.NewContext()
-    if err != nil {
-        lmcli.Fatal("%v\n", err)
-    }
-
-    root := cmd.RootCmd(ctx)
-    if err := root.Execute(); err != nil {
-        lmcli.Fatal("%v\n", err)
+    if err := cli.Execute(); err != nil {
+        fmt.Fprintln(os.Stderr, err.Error())
+        os.Exit(1)
     }
 }
pkg/cli/cli.go (new file, 32 lines)
@@ -0,0 +1,32 @@
package cli

import (
    "fmt"
    "os"
)

var config *Config
var store *Store

func init() {
    var err error

    config, err = NewConfig()
    if err != nil {
        Fatal("%v\n", err)
    }

    store, err = NewStore()
    if err != nil {
        Fatal("%v\n", err)
    }
}

func Fatal(format string, args ...any) {
    fmt.Fprintf(os.Stderr, format, args...)
    os.Exit(1)
}

func Warn(format string, args ...any) {
    fmt.Fprintf(os.Stderr, format, args...)
}
pkg/cli/cmd.go (new file, 719 lines)
@ -0,0 +1,719 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
maxTokens int
|
||||
model string
|
||||
systemPrompt string
|
||||
systemPromptFile string
|
||||
)
|
||||
|
||||
const (
|
||||
// Limit number of conversations shown with `ls`, without --all
|
||||
LS_LIMIT int = 25
|
||||
)
|
||||
|
||||
func init() {
|
||||
inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd, retryCmd, continueCmd, editCmd}
|
||||
for _, cmd := range inputCmds {
|
||||
cmd.Flags().IntVar(&maxTokens, "length", *config.OpenAI.DefaultMaxLength, "Maximum response tokens")
|
||||
cmd.Flags().StringVar(&model, "model", *config.OpenAI.DefaultModel, "Which model to use")
|
||||
cmd.Flags().StringVar(&systemPrompt, "system-prompt", *config.ModelDefaults.SystemPrompt, "System prompt")
|
||||
cmd.Flags().StringVar(&systemPromptFile, "system-prompt-file", "", "A path to a file containing the system prompt")
|
||||
cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
|
||||
}
|
||||
|
||||
listCmd.Flags().Bool("all", false, fmt.Sprintf("Show all conversations, by default only the last %d are shown", LS_LIMIT))
|
||||
renameCmd.Flags().Bool("generate", false, "Generate a conversation title")
|
||||
editCmd.Flags().Int("offset", 1, "Offset from the last reply to edit (Default: edit your last message, assuming there's an assistant reply)")
|
||||
|
||||
rootCmd.AddCommand(
|
||||
cloneCmd,
|
||||
continueCmd,
|
||||
editCmd,
|
||||
listCmd,
|
||||
newCmd,
|
||||
promptCmd,
|
||||
renameCmd,
|
||||
replyCmd,
|
||||
retryCmd,
|
||||
rmCmd,
|
||||
viewCmd,
|
||||
)
|
||||
}
|
||||
|
||||
func Execute() error {
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
|
||||
func getSystemPrompt() string {
|
||||
if systemPromptFile != "" {
|
||||
content, err := FileContents(systemPromptFile)
|
||||
if err != nil {
|
||||
Fatal("Could not read file contents at %s: %v\n", systemPromptFile, err)
|
||||
}
|
||||
return content
|
||||
}
|
||||
return systemPrompt
|
||||
}
|
||||
|
||||
// fetchAndShowCompletion prompts the LLM with the given messages and streams
|
||||
// the response to stdout. Returns all model reply messages.
|
||||
func fetchAndShowCompletion(messages []Message) ([]Message, error) {
|
||||
content := make(chan string) // receives the response from the LLM
|
||||
defer close(content)
|
||||
|
||||
// render all content received over the channel
|
||||
go ShowDelayedContent(content)
|
||||
|
||||
var replies []Message
|
||||
response, err := CreateChatCompletionStream(model, messages, maxTokens, content, &replies)
|
||||
if response != "" {
|
||||
// there was some content, so break to a new line after it
|
||||
fmt.Println()
|
||||
|
||||
if err != nil {
|
||||
Warn("Received partial response. Error: %v\n", err)
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return replies, err
|
||||
}
|
||||
|
||||
// lookupConversation either returns the conversation found by the
|
||||
// short name or exits the program
|
||||
func lookupConversation(shortName string) *Conversation {
|
||||
c, err := store.ConversationByShortName(shortName)
|
||||
if err != nil {
|
||||
Fatal("Could not lookup conversation: %v\n", err)
|
||||
}
|
||||
if c.ID == 0 {
|
||||
Fatal("Conversation not found with short name: %s\n", shortName)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func lookupConversationE(shortName string) (*Conversation, error) {
|
||||
c, err := store.ConversationByShortName(shortName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not lookup conversation: %v", err)
|
||||
}
|
||||
if c.ID == 0 {
|
||||
return nil, fmt.Errorf("Conversation not found with short name: %s", shortName)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// handleConversationReply handles sending messages to an existing
|
||||
// conversation, optionally persisting both the sent replies and responses.
|
||||
func handleConversationReply(c *Conversation, persist bool, toSend ...Message) {
|
||||
existing, err := store.Messages(c)
|
||||
if err != nil {
|
||||
Fatal("Could not retrieve messages for conversation: %s\n", c.Title)
|
||||
}
|
||||
|
||||
if persist {
|
||||
for _, message := range toSend {
|
||||
err = store.SaveMessage(&message)
|
||||
if err != nil {
|
||||
Warn("Could not save %s message: %v\n", message.Role, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
allMessages := append(existing, toSend...)
|
||||
|
||||
RenderConversation(allMessages, true)
|
||||
|
||||
// render a message header with no contents
|
||||
(&Message{Role: MessageRoleAssistant}).RenderTTY()
|
||||
|
||||
replies, err := fetchAndShowCompletion(allMessages)
|
||||
if err != nil {
|
||||
Fatal("Error fetching LLM response: %v\n", err)
|
||||
}
|
||||
|
||||
if persist {
|
||||
for _, reply := range replies {
|
||||
reply.ConversationID = c.ID
|
||||
|
||||
err = store.SaveMessage(&reply)
|
||||
if err != nil {
|
||||
Warn("Could not save reply: %v\n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// inputFromArgsOrEditor returns either the provided input from the args slice
|
||||
// (joined with spaces), or if len(args) is 0, opens an editor and returns
|
||||
// whatever input was provided there. placeholder is a string which populates
|
||||
// the editor and gets stripped from the final output.
|
||||
func inputFromArgsOrEditor(args []string, placeholder string, existingMessage string) (message string) {
|
||||
var err error
|
||||
if len(args) == 0 {
|
||||
message, err = InputFromEditor(placeholder, "message.*.md", existingMessage)
|
||||
if err != nil {
|
||||
Fatal("Failed to get input: %v\n", err)
|
||||
}
|
||||
} else {
|
||||
message = strings.Trim(strings.Join(args, " "), " \t\n")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "lmcli <command> [flags]",
|
||||
Long: `lmcli - Large Language Model CLI`,
|
||||
SilenceErrors: true,
|
||||
SilenceUsage: true,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
cmd.Usage()
|
||||
},
|
||||
}
|
||||
|
||||
var listCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List conversations",
|
||||
Long: `List conversations in order of recent activity`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
conversations, err := store.Conversations()
|
||||
if err != nil {
|
||||
Fatal("Could not fetch conversations.\n")
|
||||
return
|
||||
}
|
||||
|
||||
type Category struct {
|
||||
name string
|
||||
cutoff time.Duration
|
||||
}
|
||||
|
||||
type ConversationLine struct {
|
||||
timeSinceReply time.Duration
|
||||
formatted string
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
midnight := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||
monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
|
||||
dayOfWeek := int(now.Weekday())
|
||||
categories := []Category{
|
||||
{"today", now.Sub(midnight)},
|
||||
{"yesterday", now.Sub(midnight.AddDate(0, 0, -1))},
|
||||
{"this week", now.Sub(midnight.AddDate(0, 0, -dayOfWeek))},
|
||||
{"last week", now.Sub(midnight.AddDate(0, 0, -(dayOfWeek + 7)))},
|
||||
{"this month", now.Sub(monthStart)},
|
||||
{"last month", now.Sub(monthStart.AddDate(0, -1, 0))},
|
||||
{"2 months ago", now.Sub(monthStart.AddDate(0, -2, 0))},
|
||||
{"3 months ago", now.Sub(monthStart.AddDate(0, -3, 0))},
|
||||
{"4 months ago", now.Sub(monthStart.AddDate(0, -4, 0))},
|
||||
{"5 months ago", now.Sub(monthStart.AddDate(0, -5, 0))},
|
||||
{"older", now.Sub(time.Time{})},
|
||||
}
|
||||
categorized := map[string][]ConversationLine{}
|
||||
|
||||
all, _ := cmd.Flags().GetBool("all")
|
||||
|
||||
for _, conversation := range conversations {
|
||||
lastMessage, err := store.LastMessage(&conversation)
|
||||
if lastMessage == nil || err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
messageAge := now.Sub(lastMessage.CreatedAt)
|
||||
|
||||
var category string
|
||||
for _, c := range categories {
|
||||
if messageAge < c.cutoff {
|
||||
category = c.name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
formatted := fmt.Sprintf(
|
||||
"%s - %s - %s",
|
||||
conversation.ShortName.String,
|
||||
humanTimeElapsedSince(messageAge),
|
||||
conversation.Title,
|
||||
)
|
||||
|
||||
categorized[category] = append(
|
||||
categorized[category],
|
||||
ConversationLine{messageAge, formatted},
|
||||
)
|
||||
}
|
||||
|
||||
var conversationsPrinted int
|
||||
outer:
|
||||
for _, category := range categories {
|
||||
conversations, ok := categorized[category.name]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
slices.SortFunc(conversations, func(a, b ConversationLine) int {
|
||||
return int(a.timeSinceReply - b.timeSinceReply)
|
||||
})
|
||||
|
||||
fmt.Printf("%s:\n", category.name)
|
||||
for _, conv := range conversations {
|
||||
if conversationsPrinted >= LS_LIMIT && !all {
|
||||
fmt.Printf("%d remaining message(s), use --all to view.\n", len(conversations)-conversationsPrinted)
|
||||
break outer
|
||||
}
|
||||
|
||||
fmt.Printf(" %s\n", conv.formatted)
|
||||
conversationsPrinted++
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
var rmCmd = &cobra.Command{
|
||||
Use: "rm <conversation>...",
|
||||
Short: "Remove conversations",
|
||||
Long: `Remove conversations by their short names.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
var toRemove []*Conversation
|
||||
for _, shortName := range args {
|
||||
conversation := lookupConversation(shortName)
|
||||
toRemove = append(toRemove, conversation)
|
||||
}
|
||||
var errors []error
|
||||
for _, c := range toRemove {
|
||||
err := store.DeleteConversation(c)
|
||||
if err != nil {
|
||||
errors = append(errors, fmt.Errorf("Could not remove conversation %s: %v", c.ShortName.String, err))
|
||||
}
|
||||
}
|
||||
for _, err := range errors {
|
||||
fmt.Fprintln(os.Stderr, err.Error())
|
||||
}
|
||||
if len(errors) > 0 {
|
||||
os.Exit(1)
|
||||
}
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
var completions []string
|
||||
outer:
|
||||
for _, completion := range store.ConversationShortNameCompletions(toComplete) {
|
||||
parts := strings.Split(completion, "\t")
|
||||
for _, arg := range args {
|
||||
if parts[0] == arg {
|
||||
continue outer
|
||||
}
|
||||
}
|
||||
completions = append(completions, completion)
|
||||
}
|
||||
return completions, compMode
|
||||
},
|
||||
}
|
||||
|
||||
var cloneCmd = &cobra.Command{
|
||||
Use: "clone <conversation>",
|
||||
Short: "Clone conversations",
|
||||
Long: `Clones the provided conversation.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
shortName := args[0]
|
||||
toClone, err := lookupConversationE(shortName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
messagesToCopy, err := store.Messages(toClone)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", toClone.ShortName.String)
|
||||
}
|
||||
|
||||
clone := &Conversation{
|
||||
Title: toClone.Title + " - Clone",
|
||||
}
|
||||
if err := store.SaveConversation(clone); err != nil {
|
||||
return fmt.Errorf("Cloud not create clone: %s", err)
|
||||
}
|
||||
|
||||
var errors []error
|
||||
messageCnt := 0
|
||||
for _, message := range messagesToCopy {
|
||||
newMessage := message
|
||||
newMessage.ConversationID = clone.ID
|
||||
newMessage.ID = 0
|
||||
if err := store.SaveMessage(&newMessage); err != nil {
|
||||
errors = append(errors, err)
|
||||
} else {
|
||||
messageCnt++
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("Messages failed to be cloned: %v", errors)
|
||||
}
|
||||
|
||||
fmt.Printf("Cloned %d messages to: %s\n", messageCnt, clone.Title)
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
var viewCmd = &cobra.Command{
|
||||
Use: "view <conversation>",
|
||||
Short: "View messages in a conversation",
|
||||
Long: `Finds a conversation by its short name and displays its contents.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
shortName := args[0]
|
||||
conversation := lookupConversation(shortName)
|
||||
|
||||
messages, err := store.Messages(conversation)
|
||||
if err != nil {
|
||||
Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title)
|
||||
}
|
||||
|
||||
RenderConversation(messages, false)
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
var renameCmd = &cobra.Command{
|
||||
Use: "rename <conversation> [title]",
|
||||
Short: "Rename a conversation",
|
||||
Long: `Renames a conversation, either with the provided title or by generating a new name.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
shortName := args[0]
|
||||
conversation := lookupConversation(shortName)
|
||||
var err error
|
||||
|
||||
generate, _ := cmd.Flags().GetBool("generate")
|
||||
var title string
|
||||
if generate {
|
||||
title, err = conversation.GenerateTitle()
|
||||
if err != nil {
|
||||
Fatal("Could not generate conversation title: %v\n", err)
|
||||
}
|
||||
} else {
|
||||
if len(args) < 2 {
|
||||
Fatal("Conversation title not provided.\n")
|
||||
}
|
||||
title = strings.Join(args[1:], " ")
|
||||
}
|
||||
|
||||
conversation.Title = title
|
||||
err = store.SaveConversation(conversation)
|
||||
if err != nil {
|
||||
Warn("Could not save conversation with new title: %v\n", err)
|
||||
}
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
var replyCmd = &cobra.Command{
|
||||
Use: "reply <conversation> [message]",
|
||||
Short: "Reply to a conversation",
|
||||
Long: `Sends a reply to conversation and writes the response to stdout.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
shortName := args[0]
|
||||
conversation := lookupConversation(shortName)
|
||||
|
||||
reply := inputFromArgsOrEditor(args[1:], "# How would you like to reply?\n", "")
|
||||
if reply == "" {
|
||||
Fatal("No reply was provided.\n")
|
||||
}
|
||||
|
||||
handleConversationReply(conversation, true, Message{
|
||||
ConversationID: conversation.ID,
|
||||
Role: MessageRoleUser,
|
||||
OriginalContent: reply,
|
||||
})
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
var newCmd = &cobra.Command{
|
||||
Use: "new [message]",
|
||||
Short: "Start a new conversation",
|
||||
Long: `Start a new conversation with the Large Language Model.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
messageContents := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
|
||||
if messageContents == "" {
|
||||
Fatal("No message was provided.\n")
|
||||
}
|
||||
|
||||
conversation := &Conversation{}
|
||||
err := store.SaveConversation(conversation)
|
||||
if err != nil {
|
||||
Fatal("Could not save new conversation: %v\n", err)
|
||||
}
|
||||
|
||||
messages := []Message{
|
||||
{
|
||||
ConversationID: conversation.ID,
|
||||
Role: MessageRoleSystem,
|
||||
OriginalContent: getSystemPrompt(),
|
||||
},
|
||||
{
|
||||
ConversationID: conversation.ID,
|
||||
Role: MessageRoleUser,
|
||||
OriginalContent: messageContents,
|
||||
},
|
||||
}
|
||||
|
||||
handleConversationReply(conversation, true, messages...)
|
||||
|
||||
title, err := conversation.GenerateTitle()
|
||||
if err != nil {
|
||||
Warn("Could not generate title for conversation: %v\n", err)
|
||||
}
|
||||
|
||||
conversation.Title = title
|
||||
|
||||
err = store.SaveConversation(conversation)
|
||||
if err != nil {
|
||||
Warn("Could not save conversation after generating title: %v\n", err)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
var promptCmd = &cobra.Command{
|
||||
Use: "prompt [message]",
|
||||
Short: "Do a one-shot prompt",
|
||||
Long: `Prompt the Large Language Model and get a response.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
message := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
|
||||
if message == "" {
|
||||
Fatal("No message was provided.\n")
|
||||
}
|
||||
|
||||
messages := []Message{
|
||||
{
|
||||
Role: MessageRoleSystem,
|
||||
OriginalContent: getSystemPrompt(),
|
||||
},
|
||||
{
|
||||
Role: MessageRoleUser,
|
||||
OriginalContent: message,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := fetchAndShowCompletion(messages)
|
||||
if err != nil {
|
||||
Fatal("Error fetching LLM response: %v\n", err)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
var retryCmd = &cobra.Command{
|
||||
Use: "retry <conversation>",
|
||||
Short: "Retry the last user reply in a conversation",
|
||||
Long: `Re-prompt the conversation up to the last user response. Can be used to regenerate the last assistant reply, or simply generate one if an error occurred.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
shortName := args[0]
|
||||
conversation := lookupConversation(shortName)
|
||||
|
||||
messages, err := store.Messages(conversation)
|
||||
if err != nil {
|
||||
Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title)
|
||||
}
|
||||
|
||||
// walk backwards through the conversation and delete messages, break
|
||||
// when we find the latest user response
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == MessageRoleUser {
|
||||
break
|
||||
}
|
||||
|
||||
err = store.DeleteMessage(&messages[i])
|
||||
if err != nil {
|
||||
Warn("Could not delete previous reply: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
handleConversationReply(conversation, true)
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
var continueCmd = &cobra.Command{
|
||||
Use: "continue <conversation>",
|
||||
Short: "Continue a conversation from the last message",
|
||||
Long: `Re-prompt the conversation with all existing prompts. Useful if a reply was cut short.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
shortName := args[0]
|
||||
conversation := lookupConversation(shortName)
|
||||
handleConversationReply(conversation, true)
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
var editCmd = &cobra.Command{
|
||||
Use: "edit <conversation>",
|
||||
Short: "Edit the last user reply in a conversation",
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
shortName := args[0]
|
||||
conversation := lookupConversation(shortName)
|
||||
|
||||
messages, err := store.Messages(conversation)
|
||||
if err != nil {
|
||||
Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title)
|
||||
}
|
||||
|
||||
offset, _ := cmd.Flags().GetInt("offset")
|
||||
if offset < 0 {
|
||||
offset = -offset
|
||||
}
|
||||
|
||||
if offset > len(messages) - 1 {
|
||||
Fatal("Offset %d is before the start of the conversation\n", offset)
|
||||
}
|
||||
|
||||
desiredIdx := len(messages) - 1 - offset
|
||||
|
||||
// walk backwards through the conversation deleting messages until and
|
||||
// including the last user message
|
||||
toRemove := []Message{}
|
||||
var toEdit *Message
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if i == desiredIdx {
|
||||
toEdit = &messages[i]
|
||||
}
|
||||
toRemove = append(toRemove, messages[i])
|
||||
messages = messages[:i]
|
||||
if toEdit != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
existingContents := toEdit.OriginalContent
|
||||
|
||||
newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", existingContents)
|
||||
switch newContents {
|
||||
case existingContents:
|
||||
Fatal("No edits were made.\n")
|
||||
case "":
|
||||
Fatal("No message was provided.\n")
|
||||
}
|
||||
|
||||
for _, message := range toRemove {
|
||||
err = store.DeleteMessage(&message)
|
||||
if err != nil {
|
||||
Warn("Could not delete message: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
handleConversationReply(conversation, true, Message{
|
||||
ConversationID: conversation.ID,
|
||||
Role: MessageRoleUser,
|
||||
OriginalContent: newContents,
|
||||
})
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
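Every per-conversation command above follows the same cobra pattern: require at least one positional argument, resolve it with lookupConversation, and complete short names from the store. A minimal sketch of that pattern (illustrative only; exampleCmd does not exist in this commit) might look like:

package cli

import (
    "fmt"

    "github.com/spf13/cobra"
)

// exampleCmd is hypothetical; it mirrors the skeleton shared by rm, view,
// rename, reply, retry, continue, and edit in cmd.go.
var exampleCmd = &cobra.Command{
    Use:   "example <conversation>",
    Short: "Skeleton of the per-conversation command pattern",
    Args:  cobra.MinimumNArgs(1),
    Run: func(cmd *cobra.Command, args []string) {
        // Resolve the conversation by its short name, exiting on failure.
        conversation := lookupConversation(args[0])
        fmt.Println(conversation.Title)
    },
    ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
        // Only complete the first argument, and never fall back to file completion.
        if len(args) != 0 {
            return nil, cobra.ShellCompDirectiveNoFileComp
        }
        return store.ConversationShortNameCompletions(toComplete), cobra.ShellCompDirectiveNoFileComp
    },
}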
@ -1,41 +1,46 @@
|
||||
package lmcli
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||
"github.com/go-yaml/yaml"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Defaults *struct {
|
||||
ModelDefaults *struct {
|
||||
SystemPrompt *string `yaml:"systemPrompt" default:"You are a helpful assistant."`
|
||||
MaxTokens *int `yaml:"maxTokens" default:"256"`
|
||||
Temperature *float32 `yaml:"temperature" default:"0.7"`
|
||||
Model *string `yaml:"model" default:"gpt-4"`
|
||||
} `yaml:"defaults"`
|
||||
Conversations *struct {
|
||||
TitleGenerationModel *string `yaml:"titleGenerationModel" default:"gpt-3.5-turbo"`
|
||||
} `yaml:"conversations"`
|
||||
Tools *struct {
|
||||
EnabledTools *[]string `yaml:"enabledTools"`
|
||||
} `yaml:"tools"`
|
||||
} `yaml:"modelDefaults"`
|
||||
OpenAI *struct {
|
||||
APIKey *string `yaml:"apiKey" default:"your_key_here"`
|
||||
Models *[]string `yaml:"models"`
|
||||
DefaultModel *string `yaml:"defaultModel" default:"gpt-4"`
|
||||
DefaultMaxLength *int `yaml:"defaultMaxLength" default:"256"`
|
||||
EnabledTools []string `yaml:"enabledTools"`
|
||||
} `yaml:"openai"`
|
||||
Anthropic *struct {
|
||||
APIKey *string `yaml:"apiKey" default:"your_key_here"`
|
||||
Models *[]string `yaml:"models"`
|
||||
} `yaml:"anthropic"`
|
||||
Chroma *struct {
|
||||
Style *string `yaml:"style" default:"onedark"`
|
||||
Formatter *string `yaml:"formatter" default:"terminal16m"`
|
||||
} `yaml:"chroma"`
|
||||
}
|
||||
|
||||
func NewConfig(configFile string) (*Config, error) {
|
||||
func configDir() string {
|
||||
var configDir string
|
||||
|
||||
xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
|
||||
if xdgConfigHome != "" {
|
||||
configDir = filepath.Join(xdgConfigHome, "lmcli")
|
||||
} else {
|
||||
userHomeDir, _ := os.UserHomeDir()
|
||||
configDir = filepath.Join(userHomeDir, ".config/lmcli")
|
||||
}
|
||||
|
||||
os.MkdirAll(configDir, 0755)
|
||||
return configDir
|
||||
}
|
||||
|
||||
func NewConfig() (*Config, error) {
|
||||
configFile := filepath.Join(configDir(), "config.yaml")
|
||||
shouldWriteDefaults := false
|
||||
c := &Config{}
|
||||
|
||||
@ -49,11 +54,11 @@ func NewConfig(configFile string) (*Config, error) {
|
||||
yaml.Unmarshal(configBytes, c)
|
||||
}
|
||||
|
||||
shouldWriteDefaults = util.SetStructDefaults(c)
|
||||
shouldWriteDefaults = SetStructDefaults(c)
|
||||
if !configExists || shouldWriteDefaults {
|
||||
if configExists {
|
||||
fmt.Printf("Saving new defaults to configuration, backing up existing configuration to %s\n", configFile+".bak")
|
||||
os.Rename(configFile, configFile+".bak")
|
||||
fmt.Printf("Saving new defaults to configuration, backing up existing configuration to %s\n", configFile + ".bak")
|
||||
os.Rename(configFile, configFile + ".bak")
|
||||
}
|
||||
fmt.Printf("Writing configuration file to %s\n", configFile)
|
||||
file, err := os.Create(configFile)
|
pkg/cli/conversation.go (new file, 67 lines)
@@ -0,0 +1,67 @@
package cli

import (
    "fmt"
    "strings"
)

type MessageRole string

const (
    MessageRoleUser      MessageRole = "user"
    MessageRoleAssistant MessageRole = "assistant"
    MessageRoleSystem    MessageRole = "system"
)

// FriendlyRole returns a human friendly signifier for the message's role.
func (m *Message) FriendlyRole() string {
    var friendlyRole string
    switch m.Role {
    case MessageRoleUser:
        friendlyRole = "You"
    case MessageRoleSystem:
        friendlyRole = "System"
    case MessageRoleAssistant:
        friendlyRole = "Assistant"
    default:
        friendlyRole = string(m.Role)
    }
    return friendlyRole
}

func (c *Conversation) GenerateTitle() (string, error) {
    messages, err := store.Messages(c)
    if err != nil {
        return "", err
    }

    const header = "Generate a concise 4-5 word title for the conversation below."
    prompt := fmt.Sprintf("%s\n\n---\n\n%s", header, formatForExternalPrompting(messages, false))

    generateRequest := []Message{
        {
            Role:            MessageRoleUser,
            OriginalContent: prompt,
        },
    }

    model := "gpt-3.5-turbo" // use cheap model to generate title
    response, err := CreateChatCompletion(model, generateRequest, 25, nil)
    if err != nil {
        return "", err
    }

    return response, nil
}

func formatForExternalPrompting(messages []Message, system bool) string {
    sb := strings.Builder{}
    for _, message := range messages {
        if message.Role == MessageRoleSystem && !system {
            continue
        }
        sb.WriteString(fmt.Sprintf("<%s>\n", message.FriendlyRole()))
        sb.WriteString(fmt.Sprintf("\"\"\"\n%s\n\"\"\"\n\n", message.OriginalContent))
    }
    return sb.String()
}
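formatForExternalPrompting renders each non-system message as a FriendlyRole header followed by the content wrapped in triple quotes. A small sketch of what that flattening yields (hypothetical; the messages and the helper below are not part of this commit):

package cli

func exampleExternalPromptFormat() string {
    msgs := []Message{
        {Role: MessageRoleUser, OriginalContent: "What does lmcli do?"},
        {Role: MessageRoleAssistant, OriginalContent: "It is a command-line interface for chatting with LLMs."},
    }
    // Produces:
    // <You>
    // """
    // What does lmcli do?
    // """
    //
    // <Assistant>
    // """
    // It is a command-line interface for chatting with LLMs.
    // """
    return formatForExternalPrompting(msgs, false)
}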
pkg/cli/functions.go (new file, 582 lines)
@ -0,0 +1,582 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type FunctionResult struct {
|
||||
Message string `json:"message"`
|
||||
Result any `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
type FunctionParameter struct {
|
||||
Type string `json:"type"` // "string", "integer", "boolean"
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}
|
||||
|
||||
type FunctionParameters struct {
|
||||
Type string `json:"type"` // "object"
|
||||
Properties map[string]FunctionParameter `json:"properties"`
|
||||
Required []string `json:"required,omitempty"` // required function parameter names
|
||||
}
|
||||
|
||||
type AvailableTool struct {
|
||||
openai.Tool
|
||||
// The tool's implementation. Returns a string, as tool call results
|
||||
// are treated as normal messages with string contents.
|
||||
Impl func(arguments map[string]interface{}) (string, error)
|
||||
}
|
||||
|
||||
const (
|
||||
READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory).
|
||||
|
||||
Results are returned as JSON in the following format:
|
||||
{
|
||||
"message": "success", // if successful, or a different message indicating failure
|
||||
// result may be an empty array if there are no files in the directory
|
||||
"result": [
|
||||
{"name": "a_file", "type": "file", "size": 123},
|
||||
{"name": "a_directory/", "type": "dir", "size": 11},
|
||||
... // more files or directories
|
||||
]
|
||||
}
|
||||
|
||||
For files, size represents the size (in bytes) of the file.
|
||||
For directories, size represents the number of entries in that directory.`
|
||||
|
||||
READ_FILE_DESCRIPTION = `Read the contents of a text file relative to the current working directory.
|
||||
|
||||
Each line of the file is prefixed with its line number and a tab (\t) to make
it easier to see which lines to change for other modifications.
|
||||
|
||||
Example result:
|
||||
{
|
||||
"message": "success", // if successful, or a different message indicating failure
|
||||
"result": "1\tthe contents\n2\tof the file\n"
|
||||
}`
|
||||
|
||||
WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory.
|
||||
|
||||
Note: only use this tool when you've been explicitly asked to create or write to a file.
|
||||
|
||||
When using this function, you do not need to share the content you intend to write with the user first.
|
||||
|
||||
Example result:
|
||||
{
|
||||
"message": "success", // if successful, or a different message indicating failure
|
||||
}`
|
||||
|
||||
FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path.
|
||||
|
||||
Make sure your inserts match the flow and indentation of surrounding content.`
|
||||
|
||||
FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path.
|
||||
|
||||
Useful for re-writing snippets/blocks of code or entire functions.
|
||||
|
||||
Be cautious with your edits. When replacing, ensure the replacement content matches the flow and indentation of surrounding content.`
|
||||
)
|
||||
|
||||
var AvailableTools = map[string]AvailableTool{
|
||||
"read_dir": {
|
||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||
Name: "read_dir",
|
||||
Description: READ_DIR_DESCRIPTION,
|
||||
Parameters: FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]FunctionParameter{
|
||||
"relative_dir": {
|
||||
Type: "string",
|
||||
Description: "If set, read the contents of a directory relative to the current one.",
|
||||
},
|
||||
},
|
||||
},
|
||||
}},
|
||||
Impl: func(args map[string]interface{}) (string, error) {
|
||||
var relativeDir string
|
||||
tmp, ok := args["relative_dir"]
|
||||
if ok {
|
||||
relativeDir, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
return ReadDir(relativeDir), nil
|
||||
},
|
||||
},
|
||||
"read_file": {
|
||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||
Name: "read_file",
|
||||
Description: READ_FILE_DESCRIPTION,
|
||||
Parameters: FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]FunctionParameter{
|
||||
"path": {
|
||||
Type: "string",
|
||||
Description: "Path to a file within the current working directory to read.",
|
||||
},
|
||||
},
|
||||
Required: []string{"path"},
|
||||
},
|
||||
}},
|
||||
Impl: func(args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Path parameter to read_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
return ReadFile(path), nil
|
||||
},
|
||||
},
|
||||
"write_file": {
|
||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||
Name: "write_file",
|
||||
Description: WRITE_FILE_DESCRIPTION,
|
||||
Parameters: FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]FunctionParameter{
|
||||
"path": {
|
||||
Type: "string",
|
||||
Description: "Path to a file within the current working directory to write to.",
|
||||
},
|
||||
"content": {
|
||||
Type: "string",
|
||||
Description: "The content to write to the file. Overwrites any existing content!",
|
||||
},
|
||||
},
|
||||
Required: []string{"path", "content"},
|
||||
},
|
||||
}},
|
||||
Impl: func(args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Path parameter to write_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
tmp, ok = args["content"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Content parameter to write_file was not included.")
|
||||
}
|
||||
content, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||
}
|
||||
return WriteFile(path, content), nil
|
||||
},
|
||||
},
|
||||
"file_insert_lines": {
|
||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||
Name: "file_insert_lines",
|
||||
Description: FILE_INSERT_LINES_DESCRIPTION,
|
||||
Parameters: FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]FunctionParameter{
|
||||
"path": {
|
||||
Type: "string",
|
||||
Description: "Path of the file to be modified, relative to the current working directory.",
|
||||
},
|
||||
"position": {
|
||||
Type: "integer",
|
||||
Description: `Which line to insert content *before*.`,
|
||||
},
|
||||
"content": {
|
||||
Type: "string",
|
||||
Description: `The content to insert.`,
|
||||
},
|
||||
},
|
||||
Required: []string{"path", "position", "content"},
|
||||
},
|
||||
}},
|
||||
Impl: func(args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path parameter to write_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
var position int
|
||||
tmp, ok = args["position"]
|
||||
if ok {
|
||||
tmp, ok := tmp.(float64)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid position in function arguments: %v", tmp)
|
||||
}
|
||||
position = int(tmp)
|
||||
}
|
||||
var content string
|
||||
tmp, ok = args["content"]
|
||||
if ok {
|
||||
content, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
return FileInsertLines(path, position, content), nil
|
||||
},
|
||||
},
|
||||
"file_replace_lines": {
|
||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||
Name: "file_replace_lines",
|
||||
Description: FILE_REPLACE_LINES_DESCRIPTION,
|
||||
Parameters: FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]FunctionParameter{
|
||||
"path": {
|
||||
Type: "string",
|
||||
Description: "Path of the file to be modified, relative to the current working directory.",
|
||||
},
|
||||
"start_line": {
|
||||
Type: "integer",
|
||||
Description: `Line number which specifies the start of the replacement range (inclusive).`,
|
||||
},
|
||||
"end_line": {
|
||||
Type: "integer",
|
||||
Description: `Line number which specifies the end of the replacement range (inclusive). If unset, range extends to end of file.`,
|
||||
},
|
||||
"content": {
|
||||
Type: "string",
|
||||
Description: `Content to replace specified range. Omit to remove the specified range.`,
|
||||
},
|
||||
},
|
||||
Required: []string{"path", "start_line"},
|
||||
},
|
||||
}},
|
||||
Impl: func(args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path parameter to write_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
var start_line int
|
||||
tmp, ok = args["start_line"]
|
||||
if ok {
|
||||
tmp, ok := tmp.(float64)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid start_line in function arguments: %v", tmp)
|
||||
}
|
||||
start_line = int(tmp)
|
||||
}
|
||||
var end_line int
|
||||
tmp, ok = args["end_line"]
|
||||
if ok {
|
||||
tmp, ok := tmp.(float64)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid end_line in function arguments: %v", tmp)
|
||||
}
|
||||
end_line = int(tmp)
|
||||
}
|
||||
var content string
|
||||
tmp, ok = args["content"]
|
||||
if ok {
|
||||
content, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
|
||||
return FileReplaceLines(path, start_line, end_line, content), nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func resultToJson(result FunctionResult) string {
|
||||
if result.Message == "" {
|
||||
// When message not supplied, assume success
|
||||
result.Message = "success"
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
fmt.Printf("Could not marshal FunctionResult to JSON: %v\n", err)
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
// ExecuteToolCalls handles the execution of all tool_calls provided, and
|
||||
// returns their results formatted as []Message(s) with role: 'tool'.
|
||||
func ExecuteToolCalls(toolCalls []openai.ToolCall) ([]Message, error) {
|
||||
var toolResults []Message
|
||||
for _, toolCall := range toolCalls {
|
||||
if toolCall.Type != "function" {
|
||||
// unsupported tool type
|
||||
continue
|
||||
}
|
||||
|
||||
tool, ok := AvailableTools[toolCall.Function.Name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Function.Name)
|
||||
}
|
||||
|
||||
var functionArgs map[string]interface{}
|
||||
err := json.Unmarshal([]byte(toolCall.Function.Arguments), &functionArgs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not unmarshal tool arguments. Malformed JSON? Error: %v", err)
|
||||
}
|
||||
|
||||
// TODO: ability to silence this
|
||||
fmt.Fprintf(os.Stderr, "INFO: Executing tool '%s' with args %s\n", toolCall.Function.Name, toolCall.Function.Arguments)
|
||||
|
||||
// Execute the tool
|
||||
toolResult, err := tool.Impl(functionArgs)
|
||||
if err != nil {
|
||||
// This can happen if the model missed or supplied invalid tool args
|
||||
return nil, fmt.Errorf("Tool '%s' error: %v\n", toolCall.Function.Name, err)
|
||||
}
|
||||
|
||||
toolResults = append(toolResults, Message{
|
||||
Role: "tool",
|
||||
OriginalContent: toolResult,
|
||||
ToolCallID: sql.NullString{String: toolCall.ID, Valid: true},
|
||||
// name is not required since the introduction of ToolCallID
|
||||
// hypothesis: by setting it, we inform the model of what a
|
||||
// function's purpose was if future requests omit the function
|
||||
// definition
|
||||
})
|
||||
}
|
||||
return toolResults, nil
|
||||
}
|
||||
|
||||
// isPathContained attempts to verify whether `path` is the same as or
|
||||
// contained within `directory`. It is overly cautious, returning false even if
|
||||
// `path` IS contained within `directory`, but the two paths use different
|
||||
// casing, and we happen to be on a case-insensitive filesystem.
|
||||
// This is ultimately to attempt to stop an LLM from going outside of where I
|
||||
// tell it to. Additional layers of security should be considered.. run in a
|
||||
// VM/container.
|
||||
func isPathContained(directory string, path string) (bool, error) {
|
||||
// Clean and resolve symlinks for both paths
|
||||
path, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// check if path exists
|
||||
_, err = os.Stat(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return false, fmt.Errorf("Could not stat path: %v", err)
|
||||
}
|
||||
} else {
|
||||
path, err = filepath.EvalSymlinks(path)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
directory, err = filepath.Abs(directory)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
directory, err = filepath.EvalSymlinks(directory)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Case insensitive checks
|
||||
if !strings.EqualFold(path, directory) &&
|
||||
!strings.HasPrefix(strings.ToLower(path), strings.ToLower(directory)+string(os.PathSeparator)) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func isPathWithinCWD(path string) (bool, *FunctionResult) {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return false, &FunctionResult{Message: "Failed to determine current working directory"}
|
||||
}
|
||||
if ok, err := isPathContained(cwd, path); !ok {
|
||||
if err != nil {
|
||||
return false, &FunctionResult{Message: fmt.Sprintf("Could not determine whether path '%s' is within the current working directory: %s", path, err.Error())}
|
||||
}
|
||||
return false, &FunctionResult{Message: fmt.Sprintf("Path '%s' is not within the current working directory", path)}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func ReadDir(path string) string {
|
||||
// TODO(?): implement whitelist - list of directories which model is allowed to work in
|
||||
if path == "" {
|
||||
path = "."
|
||||
}
|
||||
ok, res := isPathWithinCWD(path)
|
||||
if !ok {
|
||||
return resultToJson(*res)
|
||||
}
|
||||
|
||||
files, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{
|
||||
Message: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
var dirContents []map[string]interface{}
|
||||
for _, f := range files {
|
||||
info, _ := f.Info()
|
||||
|
||||
name := f.Name()
|
||||
if strings.HasPrefix(name, ".") {
|
||||
// skip hidden files
|
||||
continue
|
||||
}
|
||||
|
||||
entryType := "file"
|
||||
size := info.Size()
|
||||
|
||||
if info.IsDir() {
|
||||
name += "/"
|
||||
entryType = "dir"
|
||||
subdirfiles, _ := os.ReadDir(filepath.Join(".", path, info.Name()))
|
||||
size = int64(len(subdirfiles))
|
||||
}
|
||||
|
||||
dirContents = append(dirContents, map[string]interface{}{
|
||||
"name": name,
|
||||
"type": entryType,
|
||||
"size": size,
|
||||
})
|
||||
}
|
||||
|
||||
return resultToJson(FunctionResult{Result: dirContents})
|
||||
}
|
||||
|
||||
func ReadFile(path string) string {
|
||||
ok, res := isPathWithinCWD(path)
|
||||
if !ok {
|
||||
return resultToJson(*res)
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
content := strings.Builder{}
|
||||
for i, line := range lines {
|
||||
content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
|
||||
}
|
||||
|
||||
return resultToJson(FunctionResult{
|
||||
Result: content.String(),
|
||||
})
|
||||
}
|
||||
|
||||
func WriteFile(path string, content string) string {
|
||||
ok, res := isPathWithinCWD(path)
|
||||
if !ok {
|
||||
return resultToJson(*res)
|
||||
}
|
||||
err := os.WriteFile(path, []byte(content), 0644)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
|
||||
}
|
||||
return resultToJson(FunctionResult{})
|
||||
}
|
||||
|
||||
func FileInsertLines(path string, position int, content string) string {
|
||||
ok, res := isPathWithinCWD(path)
|
||||
if !ok {
|
||||
return resultToJson(*res)
|
||||
}
|
||||
|
||||
// Read the existing file's content
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
|
||||
}
|
||||
_, err = os.Create(path)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())})
|
||||
}
|
||||
data = []byte{}
|
||||
}
|
||||
|
||||
if position < 1 {
|
||||
return resultToJson(FunctionResult{Message: "start_line cannot be less than 1"})
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
|
||||
|
||||
before := lines[:position-1]
|
||||
after := lines[position-1:]
|
||||
lines = append(before, append(contentLines, after...)...)
|
||||
|
||||
newContent := strings.Join(lines, "\n")
|
||||
|
||||
// Join the lines and write back to the file
|
||||
err = os.WriteFile(path, []byte(newContent), 0644)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
|
||||
}
|
||||
|
||||
return resultToJson(FunctionResult{Result: newContent})
|
||||
}
|
||||
|
||||
func FileReplaceLines(path string, startLine int, endLine int, content string) string {
|
||||
ok, res := isPathWithinCWD(path)
|
||||
if !ok {
|
||||
return resultToJson(*res)
|
||||
}
|
||||
|
||||
// Read the existing file's content
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
|
||||
}
|
||||
_, err = os.Create(path)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())})
|
||||
}
|
||||
data = []byte{}
|
||||
}
|
||||
|
||||
if startLine < 1 {
|
||||
return resultToJson(FunctionResult{Message: "start_line cannot be less than 1"})
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
|
||||
|
||||
if endLine == 0 || endLine > len(lines) {
|
||||
endLine = len(lines)
|
||||
}
|
||||
|
||||
before := lines[:startLine-1]
|
||||
after := lines[endLine:]
|
||||
|
||||
lines = append(before, append(contentLines, after...)...)
|
||||
newContent := strings.Join(lines, "\n")
|
||||
|
||||
// Join the lines and write back to the file
|
||||
err = os.WriteFile(path, []byte(newContent), 0644)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
|
||||
}
|
||||
|
||||
return resultToJson(FunctionResult{Result: newContent})
|
||||
|
||||
}
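// A hypothetical call site for the two editors above (file name, line
// numbers, and content are illustrative): positions are 1-based, content is
// spliced in before `position`, and an endLine of 0 in FileReplaceLines means
// "through the end of the file". Both paths must resolve inside the current
// working directory or isPathWithinCWD rejects them.
func demoFileEdits() {
	// Insert two lines before line 3 of notes.txt.
	fmt.Println(FileInsertLines("notes.txt", 3, "first inserted line\nsecond inserted line"))

	// Replace lines 5 through 8 of notes.txt with a single line.
	fmt.Println(FileReplaceLines("notes.txt", 5, 8, "replacement line"))
}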
|
187
pkg/cli/openai.go
Normal file
@ -0,0 +1,187 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
func CreateChatCompletionRequest(model string, messages []Message, maxTokens int) openai.ChatCompletionRequest {
|
||||
chatCompletionMessages := []openai.ChatCompletionMessage{}
|
||||
for _, m := range messages {
|
||||
message := openai.ChatCompletionMessage{
|
||||
Role: string(m.Role),
|
||||
Content: m.OriginalContent,
|
||||
}
|
||||
if m.ToolCallID.Valid {
|
||||
message.ToolCallID = m.ToolCallID.String
|
||||
}
|
||||
if m.ToolCalls.Valid {
|
||||
// unmarshal directly into message.ToolCalls
|
||||
err := json.Unmarshal([]byte(m.ToolCalls.String), &message.ToolCalls)
|
||||
if err != nil {
|
||||
// TODO: handle this; it shouldn't really happen since
|
||||
// we only save successfully marshalled data to the database
|
||||
fmt.Printf("Error unmarshalling the tool_calls JSON: %v\n", err)
|
||||
}
|
||||
}
|
||||
chatCompletionMessages = append(chatCompletionMessages, message)
|
||||
}
|
||||
|
||||
request := openai.ChatCompletionRequest{
|
||||
Model: model,
|
||||
Messages: chatCompletionMessages,
|
||||
MaxTokens: maxTokens,
|
||||
N: 1, // limit responses to 1 "choice". we use choices[0] to reference it
|
||||
}
|
||||
|
||||
var tools []openai.Tool
|
||||
for _, t := range config.OpenAI.EnabledTools {
|
||||
tool, ok := AvailableTools[t]
|
||||
if ok {
|
||||
tools = append(tools, tool.Tool)
|
||||
}
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
request.Tools = tools
|
||||
request.ToolChoice = "auto"
|
||||
}
|
||||
|
||||
return request
|
||||
}
|
||||
|
||||
// CreateChatCompletion submits a Chat Completion API request and returns the
|
||||
// response. CreateChatCompletion will recursively call itself in the case of
|
||||
// tool calls, until a response is received with the final user-facing output.
|
||||
func CreateChatCompletion(model string, messages []Message, maxTokens int, replies *[]Message) (string, error) {
|
||||
client := openai.NewClient(*config.OpenAI.APIKey)
|
||||
req := CreateChatCompletionRequest(model, messages, maxTokens)
|
||||
resp, err := client.CreateChatCompletion(context.Background(), req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
choice := resp.Choices[0]
|
||||
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
// Append the assistant's reply with its request for tool calls
|
||||
toolCallJson, _ := json.Marshal(choice.Message.ToolCalls)
|
||||
assistantReply := Message{
|
||||
Role: MessageRoleAssistant,
|
||||
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
|
||||
}
|
||||
|
||||
toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if replies != nil {
|
||||
*replies = append(append(*replies, assistantReply), toolReplies...)
|
||||
}
|
||||
|
||||
messages = append(append(messages, assistantReply), toolReplies...)
|
||||
// Recurse into CreateChatCompletion with the tool call replies added
|
||||
// to the original messages
|
||||
return CreateChatCompletion(model, messages, maxTokens, replies)
|
||||
}
|
||||
|
||||
if replies != nil {
|
||||
*replies = append(*replies, Message{
|
||||
Role: MessageRoleAssistant,
|
||||
OriginalContent: choice.Message.Content,
|
||||
})
|
||||
}
|
||||
|
||||
// Return the user-facing message.
|
||||
return choice.Message.Content, nil
|
||||
}
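// A hypothetical call site, assuming config has already been loaded and
// `messages` holds the system and user messages for this turn. The model
// name and token limit are illustrative; `replies` collects the assistant
// and tool messages produced along the way, including tool-call rounds.
func respondOnce(messages []Message) (string, []Message, error) {
	var replies []Message
	content, err := CreateChatCompletion("gpt-4", messages, 256, &replies)
	return content, replies, err
}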
|
||||
|
||||
// CreateChatCompletionStream submits a streaming Chat Completion API request
|
||||
// and both returns and streams the response to the provided output channel.
|
||||
// May return a partial response if an error occurs mid-stream.
|
||||
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string, replies *[]Message) (string, error) {
|
||||
client := openai.NewClient(*config.OpenAI.APIKey)
|
||||
req := CreateChatCompletionRequest(model, messages, maxTokens)
|
||||
|
||||
stream, err := client.CreateChatCompletionStream(context.Background(), req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
content := strings.Builder{}
|
||||
toolCalls := []openai.ToolCall{}
|
||||
|
||||
// Iterate stream segments
|
||||
for {
|
||||
response, e := stream.Recv()
|
||||
if errors.Is(e, io.EOF) {
|
||||
break
|
||||
}
|
||||
|
||||
if e != nil {
|
||||
err = e
|
||||
break
|
||||
}
|
||||
|
||||
delta := response.Choices[0].Delta
|
||||
if len(delta.ToolCalls) > 0 {
|
||||
// Construct streamed tool_call arguments
|
||||
for _, tc := range delta.ToolCalls {
|
||||
if tc.Index == nil {
|
||||
return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
|
||||
}
|
||||
if len(toolCalls) <= *tc.Index {
|
||||
toolCalls = append(toolCalls, tc)
|
||||
} else {
|
||||
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
|
||||
}
|
||||
}
|
||||
} else {
|
||||
output <- delta.Content
|
||||
content.WriteString(delta.Content)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
// Append the assistant's reply with its request for tool calls
|
||||
toolCallJson, _ := json.Marshal(toolCalls)
|
||||
|
||||
assistantReply := Message{
|
||||
Role: MessageRoleAssistant,
|
||||
OriginalContent: content.String(),
|
||||
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
|
||||
}
|
||||
|
||||
toolReplies, err := ExecuteToolCalls(toolCalls)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if replies != nil {
|
||||
*replies = append(append(*replies, assistantReply), toolReplies...)
|
||||
}
|
||||
|
||||
// Recurse into CreateChatCompletionStream with the tool call replies
|
||||
// added to the original messages
|
||||
messages = append(append(messages, assistantReply), toolReplies...)
|
||||
return CreateChatCompletionStream(model, messages, maxTokens, output, replies)
|
||||
}
|
||||
|
||||
if replies != nil {
|
||||
*replies = append(*replies, Message{
|
||||
Role: MessageRoleAssistant,
|
||||
OriginalContent: content.String(),
|
||||
})
|
||||
}
|
||||
|
||||
return content.String(), err
|
||||
}
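// A hypothetical streaming call site: chunks are printed as they arrive on
// the output channel while the full response is still returned at the end.
// The model name and token limit are illustrative.
func streamOnce(messages []Message) (string, error) {
	output := make(chan string)
	done := make(chan struct{})
	go func() {
		for chunk := range output {
			fmt.Print(chunk)
		}
		close(done)
	}()

	var replies []Message
	content, err := CreateChatCompletionStream("gpt-4", messages, 256, output, &replies)
	close(output) // the stream has finished; no more chunks will be sent
	<-done        // wait for the printer goroutine to drain the channel
	return content, err
}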
|
141
pkg/cli/store.go
Normal file
@ -0,0 +1,141 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
sqids "github.com/sqids/sqids-go"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
db *gorm.DB
|
||||
sqids *sqids.Sqids
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ConversationID uint `gorm:"foreignKey:ConversationID"`
|
||||
Conversation Conversation
|
||||
OriginalContent string
|
||||
Role MessageRole // one of: 'system', 'user', 'assistant', 'tool'
|
||||
CreatedAt time.Time
|
||||
ToolCallID sql.NullString
|
||||
ToolCalls sql.NullString // a json-encoded array of tool calls from the model
|
||||
}
|
||||
|
||||
type Conversation struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ShortName sql.NullString
|
||||
Title string
|
||||
}
|
||||
|
||||
func dataDir() string {
|
||||
var dataDir string
|
||||
|
||||
xdgDataHome := os.Getenv("XDG_DATA_HOME")
|
||||
if xdgDataHome != "" {
|
||||
dataDir = filepath.Join(xdgDataHome, "lmcli")
|
||||
} else {
|
||||
userHomeDir, _ := os.UserHomeDir()
|
||||
dataDir = filepath.Join(userHomeDir, ".local/share/lmcli")
|
||||
}
|
||||
|
||||
os.MkdirAll(dataDir, 0755)
|
||||
return dataDir
|
||||
}
|
||||
|
||||
func NewStore() (*Store, error) {
|
||||
databaseFile := filepath.Join(dataDir(), "conversations.db")
|
||||
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
|
||||
}
|
||||
|
||||
models := []any{
|
||||
&Conversation{},
|
||||
&Message{},
|
||||
}
|
||||
|
||||
for _, x := range models {
|
||||
err := db.AutoMigrate(x)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not perform database migrations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
_sqids, _ := sqids.New(sqids.Options{MinLength: 4})
|
||||
return &Store{db, _sqids}, nil
|
||||
}
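// A hypothetical sketch of the store lifecycle: open the SQLite-backed
// store, create a conversation (SaveConversation assigns a sqids-based
// ShortName on first save), and attach a first user message.
func seedConversation() error {
	store, err := NewStore()
	if err != nil {
		return err
	}

	conversation := &Conversation{Title: "Example conversation"}
	if err := store.SaveConversation(conversation); err != nil {
		return err
	}

	return store.SaveMessage(&Message{
		ConversationID:  conversation.ID,
		Role:            MessageRoleUser,
		OriginalContent: "Hello!",
	})
}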
|
||||
|
||||
func (s *Store) SaveConversation(conversation *Conversation) error {
|
||||
err := s.db.Save(&conversation).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !conversation.ShortName.Valid {
|
||||
shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)})
|
||||
conversation.ShortName = sql.NullString{String: shortName, Valid: true}
|
||||
err = s.db.Save(&conversation).Error
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) DeleteConversation(conversation *Conversation) error {
|
||||
s.db.Where("conversation_id = ?", conversation.ID).Delete(&Message{})
|
||||
return s.db.Delete(&conversation).Error
|
||||
}
|
||||
|
||||
func (s *Store) SaveMessage(message *Message) error {
|
||||
return s.db.Create(message).Error
|
||||
}
|
||||
|
||||
func (s *Store) DeleteMessage(message *Message) error {
|
||||
return s.db.Delete(&message).Error
|
||||
}
|
||||
|
||||
func (s *Store) Conversations() ([]Conversation, error) {
|
||||
var conversations []Conversation
|
||||
err := s.db.Find(&conversations).Error
|
||||
return conversations, err
|
||||
}
|
||||
|
||||
func (s *Store) ConversationShortNameCompletions(shortName string) []string {
|
||||
var completions []string
|
||||
conversations, _ := s.Conversations() // ignore error for completions
|
||||
for _, conversation := range conversations {
|
||||
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
|
||||
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
|
||||
}
|
||||
}
|
||||
return completions
|
||||
}
|
||||
|
||||
func (s *Store) ConversationByShortName(shortName string) (*Conversation, error) {
|
||||
if shortName == "" {
|
||||
return nil, errors.New("shortName is empty")
|
||||
}
|
||||
var conversation Conversation
|
||||
err := s.db.Where("short_name = ?", shortName).Find(&conversation).Error
|
||||
return &conversation, err
|
||||
}
|
||||
|
||||
func (s *Store) Messages(conversation *Conversation) ([]Message, error) {
|
||||
var messages []Message
|
||||
err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error
|
||||
return messages, err
|
||||
}
|
||||
|
||||
func (s *Store) LastMessage(conversation *Conversation) (*Message, error) {
|
||||
var message Message
|
||||
err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error
|
||||
return &message, err
|
||||
}
|
113
pkg/cli/tty.go
Normal file
@ -0,0 +1,113 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/alecthomas/chroma/v2/quick"
|
||||
"github.com/gookit/color"
|
||||
)
|
||||
|
||||
// ShowWaitAnimation "draws" an animated ellipses to stdout until something is
|
||||
// received on the signal channel. An empty string is sent to the channel to
|
||||
// notify the caller that the animation has completed (carriage returned).
|
||||
func ShowWaitAnimation(signal chan any) {
|
||||
animationStep := 0
|
||||
for {
|
||||
select {
|
||||
case _ = <-signal:
|
||||
fmt.Print("\r")
|
||||
signal <- ""
|
||||
return
|
||||
default:
|
||||
modSix := animationStep % 6
|
||||
if modSix == 3 || modSix == 0 {
|
||||
fmt.Print("\r")
|
||||
}
|
||||
if modSix < 3 {
|
||||
fmt.Print(".")
|
||||
} else {
|
||||
fmt.Print(" ")
|
||||
}
|
||||
animationStep++
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ShowDelayedContent displays a waiting animation to stdout while waiting
|
||||
// for content to be received on the provided channel. As soon as any (possibly
|
||||
// chunked) content is received on the channel, the waiting animation is
|
||||
// replaced by the content.
|
||||
// Blocks until the channel is closed.
|
||||
func ShowDelayedContent(content <-chan string) {
|
||||
waitSignal := make(chan any)
|
||||
go ShowWaitAnimation(waitSignal)
|
||||
|
||||
firstChunk := true
|
||||
for chunk := range content {
|
||||
if firstChunk {
|
||||
// notify wait animation that we've received data
|
||||
waitSignal <- ""
|
||||
// wait for signal that wait animation has completed
|
||||
<-waitSignal
|
||||
firstChunk = false
|
||||
}
|
||||
fmt.Print(chunk)
|
||||
}
|
||||
}
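// A hypothetical demonstration of ShowDelayedContent: the ellipsis animation
// runs until the first chunk arrives, after which chunks are printed as they
// are received. The delays simulate a slow API response.
func demoDelayedContent() {
	content := make(chan string)
	go func() {
		defer close(content)
		time.Sleep(2 * time.Second)
		for _, chunk := range []string{"Hello", ", ", "world", "!\n"} {
			content <- chunk
			time.Sleep(200 * time.Millisecond)
		}
	}()
	ShowDelayedContent(content) // blocks until the channel is closed
}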
|
||||
|
||||
// RenderConversation renders the given messages to TTY, with optional space
|
||||
// for a subsequent message. spaceForResponse controls how many '\n' characters
|
||||
// are printed immediately after the final message (1 if false, 2 if true)
|
||||
func RenderConversation(messages []Message, spaceForResponse bool) {
|
||||
l := len(messages)
|
||||
for i, message := range messages {
|
||||
message.RenderTTY()
|
||||
if i < l-1 || spaceForResponse {
|
||||
// print an additional space before the next message
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HighlightMarkdown applies syntax highlighting to the provided markdown text
|
||||
// and writes it to stdout.
|
||||
func HighlightMarkdown(markdownText string) error {
|
||||
return quick.Highlight(os.Stdout, markdownText, "md", *config.Chroma.Formatter, *config.Chroma.Style)
|
||||
}
|
||||
|
||||
func (m *Message) RenderTTY() {
|
||||
var messageAge string
|
||||
if m.CreatedAt.IsZero() {
|
||||
messageAge = "now"
|
||||
} else {
|
||||
now := time.Now()
|
||||
messageAge = humanTimeElapsedSince(now.Sub(m.CreatedAt))
|
||||
}
|
||||
|
||||
var roleStyle color.Style
|
||||
switch m.Role {
|
||||
case MessageRoleSystem:
|
||||
roleStyle = color.Style{color.HiRed}
|
||||
case MessageRoleUser:
|
||||
roleStyle = color.Style{color.HiGreen}
|
||||
case MessageRoleAssistant:
|
||||
roleStyle = color.Style{color.HiBlue}
|
||||
default:
|
||||
roleStyle = color.Style{color.FgWhite}
|
||||
}
|
||||
roleStyle.Add(color.Bold)
|
||||
|
||||
headerColor := color.FgYellow
|
||||
separator := headerColor.Sprint("===")
|
||||
timestamp := headerColor.Sprint(messageAge)
|
||||
role := roleStyle.Sprint(m.FriendlyRole())
|
||||
|
||||
fmt.Printf("%s %s - %s %s\n\n", separator, role, timestamp, separator)
|
||||
if m.OriginalContent != "" {
|
||||
HighlightMarkdown(m.OriginalContent)
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package util
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@ -56,7 +56,7 @@ func InputFromEditor(placeholder string, pattern string, content string) (string
|
||||
|
||||
// humanTimeElapsedSince returns a human-friendly "in the past" representation
|
||||
// of the given duration.
|
||||
func HumanTimeElapsedSince(d time.Duration) string {
|
||||
func humanTimeElapsedSince(d time.Duration) string {
|
||||
seconds := d.Seconds()
|
||||
minutes := seconds / 60
|
||||
hours := minutes / 60
|
||||
@ -151,14 +151,6 @@ func SetStructDefaults(data interface{}) bool {
|
||||
intValue, _ := strconv.ParseInt(defaultTag, 10, 64)
|
||||
field.Set(reflect.New(e))
|
||||
field.Elem().SetInt(intValue)
|
||||
case reflect.Float32:
|
||||
floatValue, _ := strconv.ParseFloat(defaultTag, 32)
|
||||
field.Set(reflect.New(e))
|
||||
field.Elem().SetFloat(floatValue)
|
||||
case reflect.Float64:
|
||||
floatValue, _ := strconv.ParseFloat(defaultTag, 64)
|
||||
field.Set(reflect.New(e))
|
||||
field.Elem().SetFloat(floatValue)
|
||||
case reflect.Bool:
|
||||
boolValue := defaultTag == "true"
|
||||
field.Set(reflect.ValueOf(&boolValue))
|
||||
@ -168,8 +160,10 @@ func SetStructDefaults(data interface{}) bool {
|
||||
return changed
|
||||
}
|
||||
|
||||
// ReadFileContents returns the string contents of the given file.
|
||||
func ReadFileContents(file string) (string, error) {
|
||||
// FileContents returns the string contents of the given file.
|
||||
// TODO: we should support retrieving the content (or an approximation of)
|
||||
// non-text documents, e.g. PDFs.
|
||||
func FileContents(file string) (string, error) {
|
||||
path := filepath.Clean(file)
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
@ -1,72 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func CloneCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "clone <conversation>",
|
||||
Short: "Clone conversations",
|
||||
Long: `Clones the provided conversation.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
shortName := args[0]
|
||||
toClone, err := cmdutil.LookupConversationE(ctx, shortName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
messagesToCopy, err := ctx.Store.Messages(toClone)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", toClone.ShortName.String)
|
||||
}
|
||||
|
||||
clone := &model.Conversation{
|
||||
Title: toClone.Title + " - Clone",
|
||||
}
|
||||
if err := ctx.Store.SaveConversation(clone); err != nil {
|
||||
return fmt.Errorf("Cloud not create clone: %s", err)
|
||||
}
|
||||
|
||||
var errors []error
|
||||
messageCnt := 0
|
||||
for _, message := range messagesToCopy {
|
||||
newMessage := message
|
||||
newMessage.ConversationID = clone.ID
|
||||
newMessage.ID = 0
|
||||
if err := ctx.Store.SaveMessage(&newMessage); err != nil {
|
||||
errors = append(errors, err)
|
||||
} else {
|
||||
messageCnt++
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("Messages failed to be cloned: %v", errors)
|
||||
}
|
||||
|
||||
fmt.Printf("Cloned %d messages to: %s\n", messageCnt, clone.Title)
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
@ -1,93 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
systemPromptFile string
|
||||
)
|
||||
|
||||
func RootCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
var root = &cobra.Command{
|
||||
Use: "lmcli <command> [flags]",
|
||||
Long: `lmcli - Large Language Model CLI`,
|
||||
SilenceErrors: true,
|
||||
SilenceUsage: true,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
cmd.Usage()
|
||||
},
|
||||
}
|
||||
|
||||
continueCmd := ContinueCmd(ctx)
|
||||
cloneCmd := CloneCmd(ctx)
|
||||
editCmd := EditCmd(ctx)
|
||||
listCmd := ListCmd(ctx)
|
||||
newCmd := NewCmd(ctx)
|
||||
promptCmd := PromptCmd(ctx)
|
||||
renameCmd := RenameCmd(ctx)
|
||||
replyCmd := ReplyCmd(ctx)
|
||||
retryCmd := RetryCmd(ctx)
|
||||
rmCmd := RemoveCmd(ctx)
|
||||
viewCmd := ViewCmd(ctx)
|
||||
|
||||
inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd, retryCmd, continueCmd, editCmd}
|
||||
for _, cmd := range inputCmds {
|
||||
cmd.Flags().StringVar(ctx.Config.Defaults.Model, "model", *ctx.Config.Defaults.Model, "Which model to use")
|
||||
cmd.Flags().IntVar(ctx.Config.Defaults.MaxTokens, "length", *ctx.Config.Defaults.MaxTokens, "Maximum response tokens")
|
||||
cmd.Flags().StringVar(ctx.Config.Defaults.SystemPrompt, "system-prompt", *ctx.Config.Defaults.SystemPrompt, "System prompt")
|
||||
cmd.Flags().StringVar(&systemPromptFile, "system-prompt-file", "", "A path to a file containing the system prompt")
|
||||
cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
|
||||
}
|
||||
|
||||
renameCmd.Flags().Bool("generate", false, "Generate a conversation title")
|
||||
|
||||
root.AddCommand(
|
||||
cloneCmd,
|
||||
continueCmd,
|
||||
editCmd,
|
||||
listCmd,
|
||||
newCmd,
|
||||
promptCmd,
|
||||
renameCmd,
|
||||
replyCmd,
|
||||
retryCmd,
|
||||
rmCmd,
|
||||
viewCmd,
|
||||
)
|
||||
|
||||
return root
|
||||
}
|
||||
|
||||
func getSystemPrompt(ctx *lmcli.Context) string {
|
||||
if systemPromptFile != "" {
|
||||
content, err := util.ReadFileContents(systemPromptFile)
|
||||
if err != nil {
|
||||
lmcli.Fatal("Could not read file contents at %s: %v\n", systemPromptFile, err)
|
||||
}
|
||||
return content
|
||||
}
|
||||
return *ctx.Config.Defaults.SystemPrompt
|
||||
}
|
||||
|
||||
// inputFromArgsOrEditor returns either the provided input from the args slice
|
||||
// (joined with spaces), or if len(args) is 0, opens an editor and returns
|
||||
// whatever input was provided there. placeholder is a string which populates
|
||||
// the editor and gets stripped from the final output.
|
||||
func inputFromArgsOrEditor(args []string, placeholder string, existingMessage string) (message string) {
|
||||
var err error
|
||||
if len(args) == 0 {
|
||||
message, err = util.InputFromEditor(placeholder, "message.*.md", existingMessage)
|
||||
if err != nil {
|
||||
lmcli.Fatal("Failed to get input: %v\n", err)
|
||||
}
|
||||
} else {
|
||||
message = strings.Join(args, " ")
|
||||
}
|
||||
message = strings.Trim(message, " \t\n")
|
||||
return
|
||||
}
|
@ -1,72 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "continue <conversation>",
|
||||
Short: "Continue a conversation from the last message",
|
||||
Long: `Re-prompt the conversation with all existing prompts. Useful if a reply was cut short.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
shortName := args[0]
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
|
||||
messages, err := ctx.Store.Messages(conversation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not retrieve conversation messages: %v", err)
|
||||
}
|
||||
|
||||
if len(messages) < 2 {
|
||||
return fmt.Errorf("conversation expected to have at least 2 messages")
|
||||
}
|
||||
|
||||
lastMessage := &messages[len(messages)-1]
|
||||
if lastMessage.Role != model.MessageRoleAssistant {
|
||||
return fmt.Errorf("the last message in the conversation is not an assistant message")
|
||||
}
|
||||
|
||||
// Output the contents of the last message so far
|
||||
fmt.Print(lastMessage.Content)
|
||||
|
||||
// Submit the LLM request, allowing it to continue the last message
|
||||
continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error fetching LLM response: %v", err)
|
||||
}
|
||||
|
||||
// Append the new response to the original message
|
||||
lastMessage.Content += strings.TrimRight(continuedOutput[0].Content, "\n\t ")
|
||||
|
||||
// Update the original message
|
||||
err = ctx.Store.UpdateMessage(lastMessage)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not update the last message: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
100
pkg/cmd/edit.go
@ -1,100 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func EditCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "edit <conversation>",
|
||||
Short: "Edit the last user reply in a conversation",
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
shortName := args[0]
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
|
||||
messages, err := ctx.Store.Messages(conversation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
||||
}
|
||||
|
||||
offset, _ := cmd.Flags().GetInt("offset")
|
||||
if offset < 0 {
|
||||
offset = -offset
|
||||
}
|
||||
|
||||
if offset > len(messages)-1 {
|
||||
return fmt.Errorf("Offset %d is before the start of the conversation.", offset)
|
||||
}
|
||||
|
||||
desiredIdx := len(messages) - 1 - offset
|
||||
|
||||
// walk backwards through the conversation deleting messages until and
|
||||
// including the last user message
|
||||
toRemove := []model.Message{}
|
||||
var toEdit *model.Message
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if i == desiredIdx {
|
||||
toEdit = &messages[i]
|
||||
}
|
||||
toRemove = append(toRemove, messages[i])
|
||||
messages = messages[:i]
|
||||
if toEdit != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", toEdit.Content)
|
||||
switch newContents {
|
||||
case toEdit.Content:
|
||||
return fmt.Errorf("No edits were made.")
|
||||
case "":
|
||||
return fmt.Errorf("No message was provided.")
|
||||
}
|
||||
|
||||
role, _ := cmd.Flags().GetString("role")
|
||||
if role == "" {
|
||||
role = string(toEdit.Role)
|
||||
} else if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) {
|
||||
return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.")
|
||||
}
|
||||
|
||||
for _, message := range toRemove {
|
||||
err = ctx.Store.DeleteMessage(&message)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not delete message: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{
|
||||
ConversationID: conversation.ID,
|
||||
Role: model.MessageRole(role),
|
||||
Content: newContents,
|
||||
})
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().Int("offset", 1, "Offset from the last message to edit")
|
||||
cmd.Flags().StringP("role", "r", "", "Role of the edited message (user or assistant)")
|
||||
|
||||
return cmd
|
||||
}
|
122
pkg/cmd/list.go
@ -1,122 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
const (
|
||||
LS_COUNT int = 5
|
||||
)
|
||||
|
||||
func ListCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List conversations",
|
||||
Long: `List conversations in order of recent activity`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
conversations, err := ctx.Store.Conversations()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not fetch conversations: %v", err)
|
||||
}
|
||||
|
||||
type Category struct {
|
||||
name string
|
||||
cutoff time.Duration
|
||||
}
|
||||
|
||||
type ConversationLine struct {
|
||||
timeSinceReply time.Duration
|
||||
formatted string
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
midnight := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||
monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
|
||||
dayOfWeek := int(now.Weekday())
|
||||
categories := []Category{
|
||||
{"today", now.Sub(midnight)},
|
||||
{"yesterday", now.Sub(midnight.AddDate(0, 0, -1))},
|
||||
{"this week", now.Sub(midnight.AddDate(0, 0, -dayOfWeek))},
|
||||
{"last week", now.Sub(midnight.AddDate(0, 0, -(dayOfWeek + 7)))},
|
||||
{"this month", now.Sub(monthStart)},
|
||||
{"last month", now.Sub(monthStart.AddDate(0, -1, 0))},
|
||||
{"2 months ago", now.Sub(monthStart.AddDate(0, -2, 0))},
|
||||
{"3 months ago", now.Sub(monthStart.AddDate(0, -3, 0))},
|
||||
{"4 months ago", now.Sub(monthStart.AddDate(0, -4, 0))},
|
||||
{"5 months ago", now.Sub(monthStart.AddDate(0, -5, 0))},
|
||||
{"older", now.Sub(time.Time{})},
|
||||
}
|
||||
categorized := map[string][]ConversationLine{}
|
||||
|
||||
all, _ := cmd.Flags().GetBool("all")
|
||||
|
||||
for _, conversation := range conversations {
|
||||
lastMessage, err := ctx.Store.LastMessage(&conversation)
|
||||
if lastMessage == nil || err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
messageAge := now.Sub(lastMessage.CreatedAt)
|
||||
|
||||
var category string
|
||||
for _, c := range categories {
|
||||
if messageAge < c.cutoff {
|
||||
category = c.name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
formatted := fmt.Sprintf(
|
||||
"%s - %s - %s",
|
||||
conversation.ShortName.String,
|
||||
util.HumanTimeElapsedSince(messageAge),
|
||||
conversation.Title,
|
||||
)
|
||||
|
||||
categorized[category] = append(
|
||||
categorized[category],
|
||||
ConversationLine{messageAge, formatted},
|
||||
)
|
||||
}
|
||||
|
||||
count, _ := cmd.Flags().GetInt("count")
|
||||
var conversationsPrinted int
|
||||
outer:
|
||||
for _, category := range categories {
|
||||
conversationLines, ok := categorized[category.name]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
slices.SortFunc(conversationLines, func(a, b ConversationLine) int {
|
||||
return int(a.timeSinceReply - b.timeSinceReply)
|
||||
})
|
||||
|
||||
fmt.Printf("%s:\n", category.name)
|
||||
for _, conv := range conversationLines {
|
||||
if conversationsPrinted >= count && !all {
|
||||
fmt.Printf("%d remaining message(s), use --all to view.\n", len(conversations)-conversationsPrinted)
|
||||
break outer
|
||||
}
|
||||
|
||||
fmt.Printf(" %s\n", conv.formatted)
|
||||
conversationsPrinted++
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().Bool("all", false, "Show all conversations")
|
||||
cmd.Flags().Int("count", LS_COUNT, "How many conversations to show")
|
||||
|
||||
return cmd
|
||||
}
|
@ -1,60 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func NewCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "new [message]",
|
||||
Short: "Start a new conversation",
|
||||
Long: `Start a new conversation with the Large Language Model.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
messageContents := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
|
||||
if messageContents == "" {
|
||||
return fmt.Errorf("No message was provided.")
|
||||
}
|
||||
|
||||
conversation := &model.Conversation{}
|
||||
err := ctx.Store.SaveConversation(conversation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not save new conversation: %v", err)
|
||||
}
|
||||
|
||||
messages := []model.Message{
|
||||
{
|
||||
ConversationID: conversation.ID,
|
||||
Role: model.MessageRoleSystem,
|
||||
Content: getSystemPrompt(ctx),
|
||||
},
|
||||
{
|
||||
ConversationID: conversation.ID,
|
||||
Role: model.MessageRoleUser,
|
||||
Content: messageContents,
|
||||
},
|
||||
}
|
||||
|
||||
cmdutil.HandleConversationReply(ctx, conversation, true, messages...)
|
||||
|
||||
title, err := cmdutil.GenerateTitle(ctx, conversation)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not generate title for conversation: %v\n", err)
|
||||
}
|
||||
|
||||
conversation.Title = title
|
||||
|
||||
err = ctx.Store.SaveConversation(conversation)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not save conversation after generating title: %v\n", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
@ -1,42 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func PromptCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "prompt [message]",
|
||||
Short: "Do a one-shot prompt",
|
||||
Long: `Prompt the Large Language Model and get a response.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
message := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
|
||||
if message == "" {
|
||||
return fmt.Errorf("No message was provided.")
|
||||
}
|
||||
|
||||
messages := []model.Message{
|
||||
{
|
||||
Role: model.MessageRoleSystem,
|
||||
Content: getSystemPrompt(ctx),
|
||||
},
|
||||
{
|
||||
Role: model.MessageRoleUser,
|
||||
Content: message,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := cmdutil.FetchAndShowCompletion(ctx, messages)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error fetching LLM response: %v", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
@ -1,60 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func RemoveCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "rm <conversation>...",
|
||||
Short: "Remove conversations",
|
||||
Long: `Remove conversations by their short names.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
var toRemove []*model.Conversation
|
||||
for _, shortName := range args {
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
toRemove = append(toRemove, conversation)
|
||||
}
|
||||
var errors []error
|
||||
for _, c := range toRemove {
|
||||
err := ctx.Store.DeleteConversation(c)
|
||||
if err != nil {
|
||||
errors = append(errors, fmt.Errorf("Could not remove conversation %s: %v", c.ShortName.String, err))
|
||||
}
|
||||
}
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("Could not remove some conversations: %v", errors)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
var completions []string
|
||||
outer:
|
||||
for _, completion := range ctx.Store.ConversationShortNameCompletions(toComplete) {
|
||||
parts := strings.Split(completion, "\t")
|
||||
for _, arg := range args {
|
||||
if parts[0] == arg {
|
||||
continue outer
|
||||
}
|
||||
}
|
||||
completions = append(completions, completion)
|
||||
}
|
||||
return completions, compMode
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
@ -1,60 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"github.com/spf13/cobra"
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
)
|
||||
|
||||
func RenameCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "rename <conversation> [title]",
|
||||
Short: "Rename a conversation",
|
||||
Long: `Renames a conversation, either with the provided title or by generating a new name.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
shortName := args[0]
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
var err error
|
||||
|
||||
generate, _ := cmd.Flags().GetBool("generate")
|
||||
var title string
|
||||
if generate {
|
||||
title, err = cmdutil.GenerateTitle(ctx, conversation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not generate conversation title: %v", err)
|
||||
}
|
||||
} else {
|
||||
if len(args) < 2 {
|
||||
return fmt.Errorf("Conversation title not provided.")
|
||||
}
|
||||
title = strings.Join(args[1:], " ")
|
||||
}
|
||||
|
||||
conversation.Title = title
|
||||
err = ctx.Store.SaveConversation(conversation)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not save conversation with new title: %v\n", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
@ -1,49 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "reply <conversation> [message]",
|
||||
Short: "Reply to a conversation",
|
||||
Long: `Sends a reply to a conversation and writes the response to stdout.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
shortName := args[0]
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
|
||||
reply := inputFromArgsOrEditor(args[1:], "# How would you like to reply?\n", "")
|
||||
if reply == "" {
|
||||
return fmt.Errorf("No reply was provided.")
|
||||
}
|
||||
|
||||
cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{
|
||||
ConversationID: conversation.ID,
|
||||
Role: model.MessageRoleUser,
|
||||
Content: reply,
|
||||
})
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
@ -1,58 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func RetryCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "retry <conversation>",
|
||||
Short: "Retry the last user reply in a conversation",
|
||||
Long: `Re-prompt the conversation up to the last user response. Can be used to regenerate the last assistant reply, or simply generate one if an error occurred.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
shortName := args[0]
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
|
||||
messages, err := ctx.Store.Messages(conversation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
||||
}
|
||||
|
||||
// walk backwards through the conversation and delete messages, break
|
||||
// when we find the latest user response
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == model.MessageRoleUser {
|
||||
break
|
||||
}
|
||||
|
||||
err = ctx.Store.DeleteMessage(&messages[i])
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not delete previous reply: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
cmdutil.HandleConversationReply(ctx, conversation, true)
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
@ -1,284 +0,0 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||
"github.com/alecthomas/chroma/v2/quick"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// FetchAndShowCompletion prompts the LLM with the given messages and streams
|
||||
// the response to stdout. Returns all model reply messages.
|
||||
func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]model.Message, error) {
|
||||
content := make(chan string) // receives the response from the LLM
|
||||
defer close(content)
|
||||
|
||||
// render all content received over the channel
|
||||
go ShowDelayedContent(content)
|
||||
|
||||
completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var toolBag []model.Tool
|
||||
for _, toolName := range *ctx.Config.Tools.EnabledTools {
|
||||
tool, ok := tools.AvailableTools[toolName]
|
||||
if ok {
|
||||
toolBag = append(toolBag, tool)
|
||||
}
|
||||
}
|
||||
|
||||
requestParams := model.RequestParameters{
|
||||
Model: *ctx.Config.Defaults.Model,
|
||||
MaxTokens: *ctx.Config.Defaults.MaxTokens,
|
||||
Temperature: *ctx.Config.Defaults.Temperature,
|
||||
ToolBag: toolBag,
|
||||
}
|
||||
|
||||
var apiReplies []model.Message
|
||||
response, err := completionProvider.CreateChatCompletionStream(
|
||||
requestParams, messages, &apiReplies, content,
|
||||
)
|
||||
if response != "" {
|
||||
// there was some content, so break to a new line after it
|
||||
fmt.Println()
|
||||
|
||||
if err != nil {
|
||||
lmcli.Warn("Received partial response. Error: %v\n", err)
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return apiReplies, err
|
||||
}
|
||||
|
||||
// LookupConversation either returns the conversation found by the
|
||||
// short name or exits the program
|
||||
func LookupConversation(ctx *lmcli.Context, shortName string) *model.Conversation {
|
||||
c, err := ctx.Store.ConversationByShortName(shortName)
|
||||
if err != nil {
|
||||
lmcli.Fatal("Could not lookup conversation: %v\n", err)
|
||||
}
|
||||
if c.ID == 0 {
|
||||
lmcli.Fatal("Conversation not found with short name: %s\n", shortName)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversation, error) {
|
||||
c, err := ctx.Store.ConversationByShortName(shortName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not lookup conversation: %v", err)
|
||||
}
|
||||
if c.ID == 0 {
|
||||
return nil, fmt.Errorf("Conversation not found with short name: %s", shortName)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// HandleConversationReply handles sending messages to an existing
|
||||
// conversation, optionally persisting both the sent replies and responses.
|
||||
func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
|
||||
existing, err := ctx.Store.Messages(c)
|
||||
if err != nil {
|
||||
lmcli.Fatal("Could not retrieve messages for conversation: %s\n", c.Title)
|
||||
}
|
||||
|
||||
if persist {
|
||||
for _, message := range toSend {
|
||||
err = ctx.Store.SaveMessage(&message)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not save %s message: %v\n", message.Role, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
allMessages := append(existing, toSend...)
|
||||
|
||||
RenderConversation(ctx, allMessages, true)
|
||||
|
||||
// render a message header with no contents
|
||||
RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
|
||||
|
||||
replies, err := FetchAndShowCompletion(ctx, allMessages)
|
||||
if err != nil {
|
||||
lmcli.Fatal("Error fetching LLM response: %v\n", err)
|
||||
}
|
||||
|
||||
if persist {
|
||||
for _, reply := range replies {
|
||||
reply.ConversationID = c.ID
|
||||
|
||||
err = ctx.Store.SaveMessage(&reply)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not save reply: %v\n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func FormatForExternalPrompt(messages []model.Message, system bool) string {
|
||||
sb := strings.Builder{}
|
||||
for _, message := range messages {
|
||||
if message.Role != model.MessageRoleUser && (message.Role != model.MessageRoleSystem || !system) {
|
||||
continue
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("<%s>\n", message.Role.FriendlyRole()))
|
||||
sb.WriteString(fmt.Sprintf("\"\"\"\n%s\n\"\"\"\n\n", message.Content))
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) {
|
||||
messages, err := ctx.Store.Messages(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
const header = "Generate a concise 4-5 word title for the conversation below."
|
||||
prompt := fmt.Sprintf("%s\n\n---\n\n%s", header, FormatForExternalPrompt(messages, false))
|
||||
|
||||
generateRequest := []model.Message{
|
||||
{
|
||||
Role: model.MessageRoleUser,
|
||||
Content: prompt,
|
||||
},
|
||||
}
|
||||
|
||||
completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Conversations.TitleGenerationModel)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
requestParams := model.RequestParameters{
|
||||
Model: *ctx.Config.Conversations.TitleGenerationModel,
|
||||
MaxTokens: 25,
|
||||
}
|
||||
|
||||
response, err := completionProvider.CreateChatCompletion(requestParams, generateRequest, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
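// A hypothetical use of GenerateTitle when renaming a conversation, assuming
// ctx and the conversation have already been resolved by the caller.
func regenerateTitle(ctx *lmcli.Context, c *model.Conversation) error {
	title, err := GenerateTitle(ctx, c)
	if err != nil {
		return err
	}
	c.Title = title
	return ctx.Store.SaveConversation(c)
}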
|
||||
|
||||
// ShowWaitAnimation prints an animated ellipsis to stdout until something is
|
||||
// received on the signal channel. An empty string is sent to the channel to
|
||||
// notify the caller that the animation has completed (carriage returned).
|
||||
func ShowWaitAnimation(signal chan any) {
|
||||
// Save the current cursor position
|
||||
fmt.Print("\033[s")
|
||||
|
||||
animationStep := 0
|
||||
for {
|
||||
select {
|
||||
case _ = <-signal:
|
||||
// Restore the cursor position
|
||||
fmt.Print("\033[u")
|
||||
signal <- ""
|
||||
return
|
||||
default:
|
||||
// Move the cursor to the saved position
|
||||
modSix := animationStep % 6
|
||||
if modSix == 3 || modSix == 0 {
|
||||
fmt.Print("\033[u")
|
||||
}
|
||||
if modSix < 3 {
|
||||
fmt.Print(".")
|
||||
} else {
|
||||
fmt.Print(" ")
|
||||
}
|
||||
animationStep++
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ShowDelayedContent displays a waiting animation to stdout while waiting
|
||||
// for content to be received on the provided channel. As soon as any (possibly
|
||||
// chunked) content is received on the channel, the waiting animation is
|
||||
// replaced by the content.
|
||||
// Blocks until the channel is closed.
|
||||
func ShowDelayedContent(content <-chan string) {
|
||||
waitSignal := make(chan any)
|
||||
go ShowWaitAnimation(waitSignal)
|
||||
|
||||
firstChunk := true
|
||||
for chunk := range content {
|
||||
if firstChunk {
|
||||
// notify wait animation that we've received data
|
||||
waitSignal <- ""
|
||||
// wait for signal that wait animation has completed
|
||||
<-waitSignal
|
||||
firstChunk = false
|
||||
}
|
||||
fmt.Print(chunk)
|
||||
}
|
||||
}
|
||||
|
||||
// RenderConversation renders the given messages to TTY, with optional space
|
||||
// for a subsequent message. spaceForResponse controls how many '\n' characters
|
||||
// are printed immediately after the final message (1 if false, 2 if true)
|
||||
func RenderConversation(ctx *lmcli.Context, messages []model.Message, spaceForResponse bool) {
|
||||
l := len(messages)
|
||||
for i, message := range messages {
|
||||
RenderMessage(ctx, &message)
|
||||
if i < l-1 || spaceForResponse {
|
||||
// print an additional space before the next message
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HighlightMarkdown applies syntax highlighting to the provided markdown text
|
||||
// and writes it to stdout.
|
||||
func HighlightMarkdown(w io.Writer, markdownText string, formatter string, style string) error {
|
||||
return quick.Highlight(w, markdownText, "md", formatter, style)
|
||||
}
|
||||
|
||||
func RenderMessage(ctx *lmcli.Context, m *model.Message) {
|
||||
var messageAge string
|
||||
if m.CreatedAt.IsZero() {
|
||||
messageAge = "now"
|
||||
} else {
|
||||
now := time.Now()
|
||||
messageAge = util.HumanTimeElapsedSince(now.Sub(m.CreatedAt))
|
||||
}
|
||||
|
||||
headerStyle := lipgloss.NewStyle().Bold(true)
|
||||
|
||||
switch m.Role {
|
||||
case model.MessageRoleSystem:
|
||||
headerStyle = headerStyle.Foreground(lipgloss.Color("9")) // bright red
|
||||
case model.MessageRoleUser:
|
||||
headerStyle = headerStyle.Foreground(lipgloss.Color("10")) // bright green
|
||||
case model.MessageRoleAssistant:
|
||||
headerStyle = headerStyle.Foreground(lipgloss.Color("12")) // bright blue
|
||||
}
|
||||
|
||||
role := headerStyle.Render(m.Role.FriendlyRole())
|
||||
|
||||
separatorStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("3"))
|
||||
separator := separatorStyle.Render("===")
|
||||
timestamp := separatorStyle.Render(messageAge)
|
||||
|
||||
fmt.Printf("%s %s - %s %s\n\n", separator, role, timestamp, separator)
|
||||
if m.Content != "" {
|
||||
HighlightMarkdown(
|
||||
os.Stdout, m.Content,
|
||||
*ctx.Config.Chroma.Formatter,
|
||||
*ctx.Config.Chroma.Style,
|
||||
)
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
@ -1,45 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func ViewCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "view <conversation>",
|
||||
Short: "View messages in a conversation",
|
||||
Long: `Finds a conversation by its short name and displays its contents.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
shortName := args[0]
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
|
||||
messages, err := ctx.Store.Messages(conversation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
||||
}
|
||||
|
||||
cmdutil.RenderConversation(ctx, messages, false)
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
@ -1,97 +0,0 @@
|
||||
package lmcli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Context struct {
|
||||
Config Config
|
||||
Store ConversationStore
|
||||
}
|
||||
|
||||
func NewContext() (*Context, error) {
|
||||
configFile := filepath.Join(configDir(), "config.yaml")
|
||||
config, err := NewConfig(configFile)
|
||||
if err != nil {
|
||||
Fatal("%v\n", err)
|
||||
}
|
||||
|
||||
databaseFile := filepath.Join(dataDir(), "conversations.db")
|
||||
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
|
||||
}
|
||||
s, err := NewSQLStore(db)
|
||||
if err != nil {
|
||||
Fatal("%v\n", err)
|
||||
}
|
||||
|
||||
return &Context{*config, s}, nil
|
||||
}
|
||||
|
||||
func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) {
|
||||
for _, m := range *c.Config.Anthropic.Models {
|
||||
if m == model {
|
||||
anthropic := &anthropic.AnthropicClient{
|
||||
APIKey: *c.Config.Anthropic.APIKey,
|
||||
}
|
||||
return anthropic, nil
|
||||
}
|
||||
}
|
||||
for _, m := range *c.Config.OpenAI.Models {
|
||||
if m == model {
|
||||
openai := &openai.OpenAIClient{
|
||||
APIKey: *c.Config.OpenAI.APIKey,
|
||||
}
|
||||
return openai, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("unknown model: %s", model)
|
||||
}
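// A hypothetical helper showing how a provider is resolved and used for a
// one-shot completion; it assumes the lmcli/model package is imported and
// that Defaults.Model and Defaults.MaxTokens are set in the config.
func completeOnce(ctx *Context, messages []model.Message) (string, error) {
	completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
	if err != nil {
		return "", err
	}
	params := model.RequestParameters{
		Model:     *ctx.Config.Defaults.Model,
		MaxTokens: *ctx.Config.Defaults.MaxTokens,
	}
	return completionProvider.CreateChatCompletion(params, messages, nil)
}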
|
||||
|
||||
func configDir() string {
|
||||
var configDir string
|
||||
|
||||
xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
|
||||
if xdgConfigHome != "" {
|
||||
configDir = filepath.Join(xdgConfigHome, "lmcli")
|
||||
} else {
|
||||
userHomeDir, _ := os.UserHomeDir()
|
||||
configDir = filepath.Join(userHomeDir, ".config/lmcli")
|
||||
}
|
||||
|
||||
os.MkdirAll(configDir, 0755)
|
||||
return configDir
|
||||
}
|
||||
|
||||
func dataDir() string {
|
||||
var dataDir string
|
||||
|
||||
xdgDataHome := os.Getenv("XDG_DATA_HOME")
|
||||
if xdgDataHome != "" {
|
||||
dataDir = filepath.Join(xdgDataHome, "lmcli")
|
||||
} else {
|
||||
userHomeDir, _ := os.UserHomeDir()
|
||||
dataDir = filepath.Join(userHomeDir, ".local/share/lmcli")
|
||||
}
|
||||
|
||||
os.MkdirAll(dataDir, 0755)
|
||||
return dataDir
|
||||
}
|
||||
|
||||
func Fatal(format string, args ...any) {
|
||||
fmt.Fprintf(os.Stderr, format, args...)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func Warn(format string, args ...any) {
|
||||
fmt.Fprintf(os.Stderr, format, args...)
|
||||
}
|
@ -1,58 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
)
|
||||
|
||||
type MessageRole string
|
||||
|
||||
const (
|
||||
MessageRoleSystem MessageRole = "system"
|
||||
MessageRoleUser MessageRole = "user"
|
||||
MessageRoleAssistant MessageRole = "assistant"
|
||||
MessageRoleToolCall MessageRole = "tool_call"
|
||||
MessageRoleToolResult MessageRole = "tool_result"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ConversationID uint `gorm:"foreignKey:ConversationID"`
|
||||
Content string
|
||||
Role MessageRole
|
||||
CreatedAt time.Time
|
||||
ToolCalls ToolCalls // a json array of tool calls (from the model)
|
||||
ToolResults ToolResults // a json array of tool results
|
||||
}
|
||||
|
||||
type Conversation struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ShortName sql.NullString
|
||||
Title string
|
||||
}
|
||||
|
||||
type RequestParameters struct {
|
||||
Model string
|
||||
MaxTokens int
|
||||
Temperature float32
|
||||
TopP float32
|
||||
|
||||
SystemPrompt string
|
||||
ToolBag []Tool
|
||||
}
|
||||
|
||||
// FriendlyRole returns a human friendly signifier for the message's role.
|
||||
func (m *MessageRole) FriendlyRole() string {
|
||||
var friendlyRole string
|
||||
switch *m {
|
||||
case MessageRoleUser:
|
||||
friendlyRole = "You"
|
||||
case MessageRoleSystem:
|
||||
friendlyRole = "System"
|
||||
case MessageRoleAssistant:
|
||||
friendlyRole = "Assistant"
|
||||
default:
|
||||
friendlyRole = string(*m)
|
||||
}
|
||||
return friendlyRole
|
||||
}
|
@ -1,98 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Tool struct {
|
||||
Name string
|
||||
Description string
|
||||
Parameters []ToolParameter
|
||||
Impl func(*Tool, map[string]interface{}) (string, error)
|
||||
}
|
||||
|
||||
type ToolParameter struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"` // "string", "integer", "boolean"
|
||||
Required bool `json:"required"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Parameters map[string]interface{} `json:"parameters"`
|
||||
}
|
||||
|
||||
type ToolCalls []ToolCall
|
||||
|
||||
func (tc *ToolCalls) Scan(value any) (err error) {
|
||||
if value == nil {
|
||||
*tc = nil
|
||||
return
|
||||
}
|
||||
s, ok := value.(string)
|
||||
if !ok || s == "" {
|
||||
*tc = nil
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal([]byte(s), tc)
|
||||
return
|
||||
}
|
||||
|
||||
func (tc ToolCalls) Value() (driver.Value, error) {
|
||||
if len(tc) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
jsonBytes, err := json.Marshal(tc)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not marshal ToolCalls to JSON: %v\n", err)
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
||||
|
||||
type ToolResult struct {
|
||||
ToolCallID string `json:"toolCallID"`
|
||||
ToolName string `json:"toolName,omitempty"`
|
||||
Result string `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
type ToolResults []ToolResult
|
||||
|
||||
func (tr *ToolResults) Scan(value any) (err error) {
|
||||
if value == nil {
|
||||
*tr = nil
|
||||
return
|
||||
}
|
||||
s, ok := value.(string)
|
||||
if !ok || s == "" {
|
||||
*tr = nil
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal([]byte(s), tr)
|
||||
return
|
||||
}
|
||||
|
||||
func (tr ToolResults) Value() (driver.Value, error) {
|
||||
if len(tr) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
jsonBytes, err := json.Marshal([]ToolResult(tr))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not marshal ToolResults to JSON: %v\n", err)
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
||||
|
||||
type CallResult struct {
|
||||
Message string `json:"message"`
|
||||
Result any `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
func (r CallResult) ToJson() (string, error) {
|
||||
if r.Message == "" {
|
||||
// When message not supplied, assume success
|
||||
r.Message = "success"
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not marshal CallResult to JSON: %v\n", err)
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
@ -1,322 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
)
|
||||
|
||||
type AnthropicClient struct {
|
||||
APIKey string
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
OriginalContent string `json:"content"`
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
System string `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
//TopP float32 `json:"top_p,omitempty"`
|
||||
//TopK float32 `json:"top_k,omitempty"`
|
||||
}
|
||||
|
||||
type OriginalContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Id string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role"`
|
||||
OriginalContent []OriginalContent `json:"content"`
|
||||
}
|
||||
|
||||
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
|
||||
|
||||
func buildRequest(params model.RequestParameters, messages []model.Message) Request {
|
||||
requestBody := Request{
|
||||
Model: params.Model,
|
||||
Messages: make([]Message, len(messages)),
|
||||
System: params.SystemPrompt,
|
||||
MaxTokens: params.MaxTokens,
|
||||
Temperature: params.Temperature,
|
||||
Stream: false,
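|
||||
// Stop generation at the closing function_calls tag (so tool calls can be parsed and executed) and at a stray Human: turn.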
|
||||
|
||||
StopSequences: []string{
|
||||
FUNCTION_STOP_SEQUENCE,
|
||||
"\n\nHuman:",
|
||||
},
|
||||
}
|
||||
|
||||
startIdx := 0
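|
||||
// A leading system message is hoisted into the request's System field and excluded from Messages.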
|
||||
if messages[0].Role == model.MessageRoleSystem {
|
||||
requestBody.System = messages[0].Content
|
||||
requestBody.Messages = requestBody.Messages[:len(messages)-1]
|
||||
startIdx = 1
|
||||
}
|
||||
|
||||
if len(params.ToolBag) > 0 {
|
||||
if len(requestBody.System) > 0 {
|
||||
// add a divider between existing system prompt and tools
|
||||
requestBody.System += "\n\n---\n\n"
|
||||
}
|
||||
requestBody.System += buildToolsSystemPrompt(params.ToolBag)
|
||||
}
|
||||
|
||||
for i, msg := range messages[startIdx:] {
|
||||
message := &requestBody.Messages[i]
|
||||
|
||||
switch msg.Role {
|
||||
case model.MessageRoleToolCall:
|
||||
message.Role = "assistant"
|
||||
message.OriginalContent = msg.Content
|
||||
//message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls)
|
||||
case model.MessageRoleToolResult:
|
||||
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
|
||||
xmlString, err := xmlFuncResults.XMLString()
|
||||
if err != nil {
|
||||
panic("Could not serialize []ToolResult to XMLFunctionResults")
|
||||
}
|
||||
message.Role = "user"
|
||||
message.OriginalContent = xmlString
|
||||
default:
|
||||
message.Role = string(msg.Role)
|
||||
message.OriginalContent = msg.Content
|
||||
}
|
||||
}
|
||||
return requestBody
|
||||
}
|
||||
|
||||
func sendRequest(c *AnthropicClient, r Request) (*http.Response, error) {
|
||||
url := "https://api.anthropic.com/v1/messages"
|
||||
|
||||
jsonBody, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %v", err)
|
||||
}
|
||||
|
||||
req.Header.Set("x-api-key", c.APIKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
req.Header.Set("content-type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send HTTP request: %v", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *AnthropicClient) CreateChatCompletion(
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
replies *[]model.Message,
|
||||
) (string, error) {
|
||||
request := buildRequest(params, messages)
|
||||
|
||||
resp, err := sendRequest(c, request)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var response Response
|
||||
err = json.NewDecoder(resp.Body).Decode(&response)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
sb := strings.Builder{}
|
||||
for _, content := range response.OriginalContent {
|
||||
var reply model.Message
|
||||
switch content.Type {
|
||||
case "text":
|
||||
reply = model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: content.Text,
|
||||
}
|
||||
sb.WriteString(reply.Content)
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported message type: %s", content.Type)
|
||||
}
|
||||
*replies = append(*replies, reply)
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
func (c *AnthropicClient) CreateChatCompletionStream(
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
replies *[]model.Message,
|
||||
output chan<- string,
|
||||
) (string, error) {
|
||||
request := buildRequest(params, messages)
|
||||
request.Stream = true
|
||||
|
||||
resp, err := sendRequest(c, request)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
sb := strings.Builder{}
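|
||||
// Read the SSE stream line by line: bare JSON objects are treated as top-level errors, while "data:" lines carry streaming events.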
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if line[0] == '{' {
|
||||
var event map[string]interface{}
|
||||
err := json.Unmarshal([]byte(line), &event)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to unmarshal event data '%s': %v", line, err)
|
||||
}
|
||||
eventType, ok := event["type"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid event: %s", line)
|
||||
}
|
||||
switch eventType {
|
||||
case "error":
|
||||
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
|
||||
default:
|
||||
return sb.String(), fmt.Errorf("unknown event type: %s", eventType)
|
||||
}
|
||||
} else if strings.HasPrefix(line, "data:") {
|
||||
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
var event map[string]interface{}
|
||||
err := json.Unmarshal([]byte(data), &event)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to unmarshal event data: %v", err)
|
||||
}
|
||||
|
||||
eventType, ok := event["type"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid event type")
|
||||
}
|
||||
|
||||
switch eventType {
|
||||
case "message_start":
|
||||
// noop
|
||||
case "ping":
|
||||
// write an empty string to signal start of text
|
||||
output <- ""
|
||||
case "content_block_start":
|
||||
// ignore?
|
||||
case "content_block_delta":
|
||||
delta, ok := event["delta"].(map[string]interface{})
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid content block delta")
|
||||
}
|
||||
text, ok := delta["text"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid text delta")
|
||||
}
|
||||
sb.WriteString(text)
|
||||
output <- text
|
||||
case "content_block_stop":
|
||||
// ignore?
|
||||
case "message_delta":
|
||||
delta, ok := event["delta"].(map[string]interface{})
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid message delta")
|
||||
}
|
||||
stopReason, ok := delta["stop_reason"].(string)
|
||||
if ok && stopReason == "stop_sequence" {
|
||||
stopSequence, ok := delta["stop_sequence"].(string)
|
||||
if ok && stopSequence == FUNCTION_STOP_SEQUENCE {
|
||||
content := sb.String()
|
||||
|
||||
start := strings.Index(content, "<function_calls>")
|
||||
if start == -1 {
|
||||
return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
|
||||
}
|
||||
|
||||
funcCallXml := content[start:]
|
||||
funcCallXml += FUNCTION_STOP_SEQUENCE
|
||||
|
||||
sb.WriteString(FUNCTION_STOP_SEQUENCE)
|
||||
output <- FUNCTION_STOP_SEQUENCE
|
||||
|
||||
// Extract function calls
|
||||
var functionCalls XMLFunctionCalls
|
||||
err := xml.Unmarshal([]byte(sb.String()), &functionCalls)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to unmarshal function_calls: %v", err)
|
||||
}
|
||||
|
||||
// Execute function calls
|
||||
toolCall := model.Message{
|
||||
Role: model.MessageRoleToolCall,
|
||||
Content: content,
|
||||
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
|
||||
}
|
||||
|
||||
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
toolReply := model.Message{
|
||||
Role: model.MessageRoleToolResult,
|
||||
ToolResults: toolResults,
|
||||
}
|
||||
|
||||
if replies != nil {
|
||||
*replies = append(append(*replies, toolCall), toolReply)
|
||||
}
|
||||
|
||||
// Recurse into CreateChatCompletionStream with the tool call replies
|
||||
// added to the original messages
|
||||
messages = append(append(messages, toolCall), toolReply)
|
||||
return c.CreateChatCompletionStream(params, messages, replies, output)
|
||||
}
|
||||
}
|
||||
case "message_stop":
|
||||
// return the completed message
|
||||
reply := model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: sb.String(),
|
||||
}
|
||||
*replies = append(*replies, reply)
|
||||
return sb.String(), nil
|
||||
case "error":
|
||||
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
|
||||
default:
|
||||
fmt.Printf("\nUnrecognized event: %s\n", data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return "", fmt.Errorf("failed to read response body: %v", err)
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("unexpected end of stream")
|
||||
}
|
@ -1,182 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
)
|
||||
|
||||
const TOOL_PREAMBLE = `In this environment you have access to a set of tools which may assist you in fulfilling user requests.
|
||||
|
||||
You may call them like this:
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>$TOOL_NAME</tool_name>
|
||||
<parameters>
|
||||
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
|
||||
...
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
|
||||
Here are the tools available:`
|
||||
|
||||
type XMLTools struct {
|
||||
XMLName struct{} `xml:"tools"`
|
||||
ToolDescriptions []XMLToolDescription `xml:"tool_description"`
|
||||
}
|
||||
|
||||
type XMLToolDescription struct {
|
||||
ToolName string `xml:"tool_name"`
|
||||
Description string `xml:"description"`
|
||||
Parameters []XMLToolParameter `xml:"parameters>parameter"`
|
||||
}
|
||||
|
||||
type XMLToolParameter struct {
|
||||
Name string `xml:"name"`
|
||||
Type string `xml:"type"`
|
||||
Description string `xml:"description"`
|
||||
}
|
||||
|
||||
type XMLFunctionCalls struct {
|
||||
XMLName struct{} `xml:"function_calls"`
|
||||
Invoke []XMLFunctionInvoke `xml:"invoke"`
|
||||
}
|
||||
|
||||
type XMLFunctionInvoke struct {
|
||||
ToolName string `xml:"tool_name"`
|
||||
Parameters XMLFunctionInvokeParameters `xml:"parameters"`
|
||||
}
|
||||
|
||||
type XMLFunctionInvokeParameters struct {
|
||||
String string `xml:",innerxml"`
|
||||
}
|
||||
|
||||
type XMLFunctionResults struct {
|
||||
XMLName struct{} `xml:"function_results"`
|
||||
Result []XMLFunctionResult `xml:"result"`
|
||||
}
|
||||
|
||||
type XMLFunctionResult struct {
|
||||
ToolName string `xml:"tool_name"`
|
||||
Stdout string `xml:"stdout"`
|
||||
}
|
||||
|
||||
// parseFunctionParametersXML accepts raw XML from XMLFunctionInvokeParameters.String and returns a map of
|
||||
// parameter names to values
|
||||
func parseFunctionParametersXML(params string) map[string]interface{} {
|
||||
lines := strings.Split(params, "\n")
|
||||
ret := make(map[string]interface{}, len(lines))
|
||||
for _, line := range lines {
|
||||
i := strings.Index(line, ">")
|
||||
if i == -1 {
|
||||
continue
|
||||
}
|
||||
j := strings.Index(line, "</")
|
||||
if j == -1 {
|
||||
continue
|
||||
}
|
||||
// chop from after opening < to first > to get parameter name,
|
||||
// then chop after > to first </ to get parameter value
|
||||
ret[line[1:i]] = line[i+1 : j]
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func convertToolsToXMLTools(tools []model.Tool) XMLTools {
|
||||
converted := make([]XMLToolDescription, len(tools))
|
||||
for i, tool := range tools {
|
||||
converted[i].ToolName = tool.Name
|
||||
converted[i].Description = tool.Description
|
||||
|
||||
params := make([]XMLToolParameter, len(tool.Parameters))
|
||||
for j, param := range tool.Parameters {
|
||||
params[j].Name = param.Name
|
||||
params[j].Description = param.Description
|
||||
params[j].Type = param.Type
|
||||
}
|
||||
|
||||
converted[i].Parameters = params
|
||||
}
|
||||
return XMLTools{
|
||||
ToolDescriptions: converted,
|
||||
}
|
||||
}
|
||||
|
||||
func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []model.ToolCall {
|
||||
toolCalls := make([]model.ToolCall, len(functionCalls.Invoke))
|
||||
for i, invoke := range functionCalls.Invoke {
|
||||
toolCalls[i].Name = invoke.ToolName
|
||||
toolCalls[i].Parameters = parseFunctionParametersXML(invoke.Parameters.String)
|
||||
}
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
func convertToolResultsToXMLFunctionResult(toolResults []model.ToolResult) XMLFunctionResults {
|
||||
converted := make([]XMLFunctionResult, len(toolResults))
|
||||
for i, result := range toolResults {
|
||||
converted[i].ToolName = result.ToolName
|
||||
converted[i].Stdout = result.Result
|
||||
}
|
||||
return XMLFunctionResults{
|
||||
Result: converted,
|
||||
}
|
||||
}
|
||||
|
||||
func buildToolsSystemPrompt(tools []model.Tool) string {
|
||||
xmlTools := convertToolsToXMLTools(tools)
|
||||
xmlToolsString, err := xmlTools.XMLString()
|
||||
if err != nil {
|
||||
panic("Could not serialize []model.Tool to XMLTools")
|
||||
}
|
||||
return TOOL_PREAMBLE + "\n" + xmlToolsString + "\n"
|
||||
}
|
||||
|
||||
func (x XMLTools) XMLString() (string, error) {
|
||||
tmpl, err := template.New("tools").Parse(`<tools>
|
||||
{{range .ToolDescriptions}}<tool_description>
|
||||
<tool_name>{{.ToolName}}</tool_name>
|
||||
<description>
|
||||
{{.Description}}
|
||||
</description>
|
||||
<parameters>
|
||||
{{range .Parameters}}<parameter>
|
||||
<name>{{.Name}}</name>
|
||||
<type>{{.Type}}</type>
|
||||
<description>{{.Description}}</description>
|
||||
</parameter>
|
||||
{{end}}</parameters>
|
||||
</tool_description>
|
||||
{{end}}</tools>`)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := tmpl.Execute(&buf, x); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func (x XMLFunctionResults) XMLString() (string, error) {
|
||||
tmpl, err := template.New("function_results").Parse(`<function_results>
|
||||
{{range .Result}}<result>
|
||||
<tool_name>{{.ToolName}}</tool_name>
|
||||
<stdout>{{.Stdout}}</stdout>
|
||||
</result>
|
||||
{{end}}</function_results>`)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := tmpl.Execute(&buf, x); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
@ -1,270 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type OpenAIClient struct {
|
||||
APIKey string
|
||||
}
|
||||
|
||||
type OpenAIToolParameters struct {
|
||||
Type string `json:"type"`
|
||||
Properties map[string]OpenAIToolParameter `json:"properties,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIToolParameter struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}
|
||||
|
||||
func convertTools(tools []model.Tool) []openai.Tool {
|
||||
openaiTools := make([]openai.Tool, len(tools))
|
||||
for i, tool := range tools {
|
||||
openaiTools[i].Type = "function"
|
||||
|
||||
params := make(map[string]OpenAIToolParameter)
|
||||
var required []string
|
||||
|
||||
for _, param := range tool.Parameters {
|
||||
params[param.Name] = OpenAIToolParameter{
|
||||
Type: param.Type,
|
||||
Description: param.Description,
|
||||
Enum: param.Enum,
|
||||
}
|
||||
if param.Required {
|
||||
required = append(required, param.Name)
|
||||
}
|
||||
}
|
||||
|
||||
openaiTools[i].Function = openai.FunctionDefinition{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
Parameters: OpenAIToolParameters{
|
||||
Type: "object",
|
||||
Properties: params,
|
||||
Required: required,
|
||||
},
|
||||
}
|
||||
}
|
||||
return openaiTools
|
||||
}
|
||||
|
||||
func convertToolCallToOpenAI(toolCalls []model.ToolCall) []openai.ToolCall {
|
||||
converted := make([]openai.ToolCall, len(toolCalls))
|
||||
for i, call := range toolCalls {
|
||||
converted[i].Type = "function"
|
||||
converted[i].ID = call.ID
|
||||
converted[i].Function.Name = call.Name
|
||||
|
||||
json, _ := json.Marshal(call.Parameters)
|
||||
converted[i].Function.Arguments = string(json)
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func convertToolCallToAPI(toolCalls []openai.ToolCall) []model.ToolCall {
|
||||
converted := make([]model.ToolCall, len(toolCalls))
|
||||
for i, call := range toolCalls {
|
||||
converted[i].ID = call.ID
|
||||
converted[i].Name = call.Function.Name
|
||||
json.Unmarshal([]byte(call.Function.Arguments), &converted[i].Parameters)
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func createChatCompletionRequest(
|
||||
c *OpenAIClient,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
) openai.ChatCompletionRequest {
|
||||
requestMessages := make([]openai.ChatCompletionMessage, 0, len(messages))
|
||||
|
||||
for _, m := range messages {
|
||||
switch m.Role {
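|
||||
// Map stored messages onto the OpenAI chat format; tool calls and tool results need special handling.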
|
||||
case "tool_call":
|
||||
message := openai.ChatCompletionMessage{}
|
||||
message.Role = "assistant"
|
||||
message.Content = m.Content
|
||||
message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls)
|
||||
requestMessages = append(requestMessages, message)
|
||||
case "tool_result":
|
||||
// expand tool_result messages' results into multiple openAI messages
|
||||
for _, result := range m.ToolResults {
|
||||
message := openai.ChatCompletionMessage{}
|
||||
message.Role = "tool"
|
||||
message.Content = result.Result
|
||||
message.ToolCallID = result.ToolCallID
|
||||
requestMessages = append(requestMessages, message)
|
||||
}
|
||||
default:
|
||||
message := openai.ChatCompletionMessage{}
|
||||
message.Role = string(m.Role)
|
||||
message.Content = m.Content
|
||||
requestMessages = append(requestMessages, message)
|
||||
}
|
||||
}
|
||||
|
||||
request := openai.ChatCompletionRequest{
|
||||
Model: params.Model,
|
||||
MaxTokens: params.MaxTokens,
|
||||
Temperature: params.Temperature,
|
||||
Messages: requestMessages,
|
||||
N: 1, // limit responses to 1 "choice". we use choices[0] to reference it
|
||||
}
|
||||
|
||||
if len(params.ToolBag) > 0 {
|
||||
request.Tools = convertTools(params.ToolBag)
|
||||
request.ToolChoice = "auto"
|
||||
}
|
||||
|
||||
return request
|
||||
}
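|
||||
// handleToolCalls executes the tool calls requested by the model and returns the tool_call and tool_result messages to record.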
|
||||
|
||||
func handleToolCalls(
|
||||
params model.RequestParameters,
|
||||
content string,
|
||||
toolCalls []openai.ToolCall,
|
||||
) ([]model.Message, error) {
|
||||
toolCall := model.Message{
|
||||
Role: model.MessageRoleToolCall,
|
||||
Content: content,
|
||||
ToolCalls: convertToolCallToAPI(toolCalls),
|
||||
}
|
||||
|
||||
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toolResult := model.Message{
|
||||
Role: model.MessageRoleToolResult,
|
||||
ToolResults: toolResults,
|
||||
}
|
||||
|
||||
return []model.Message{toolCall, toolResult}, nil
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) CreateChatCompletion(
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
replies *[]model.Message,
|
||||
) (string, error) {
|
||||
client := openai.NewClient(c.APIKey)
|
||||
req := createChatCompletionRequest(c, params, messages)
|
||||
resp, err := client.CreateChatCompletion(context.Background(), req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
choice := resp.Choices[0]
|
||||
|
||||
toolCalls := choice.Message.ToolCalls
|
||||
if len(toolCalls) > 0 {
|
||||
results, err := handleToolCalls(params, choice.Message.Content, toolCalls)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if results != nil {
|
||||
*replies = append(*replies, results...)
|
||||
}
|
||||
|
||||
// Recurse into CreateChatCompletion with the tool call replies
|
||||
messages = append(messages, results...)
|
||||
return c.CreateChatCompletion(params, messages, replies)
|
||||
}
|
||||
|
||||
if replies != nil {
|
||||
*replies = append(*replies, model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: choice.Message.Content,
|
||||
})
|
||||
}
|
||||
|
||||
// Return the user-facing message.
|
||||
return choice.Message.Content, nil
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) CreateChatCompletionStream(
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
replies *[]model.Message,
|
||||
output chan<- string,
|
||||
) (string, error) {
|
||||
client := openai.NewClient(c.APIKey)
|
||||
req := createChatCompletionRequest(c, params, messages)
|
||||
|
||||
stream, err := client.CreateChatCompletionStream(context.Background(), req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
content := strings.Builder{}
|
||||
toolCalls := []openai.ToolCall{}
|
||||
|
||||
// Iterate stream segments
|
||||
for {
|
||||
response, e := stream.Recv()
|
||||
if errors.Is(e, io.EOF) {
|
||||
break
|
||||
}
|
||||
|
||||
if e != nil {
|
||||
err = e
|
||||
break
|
||||
}
|
||||
|
||||
delta := response.Choices[0].Delta
|
||||
if len(delta.ToolCalls) > 0 {
|
||||
// Construct streamed tool_call arguments
|
||||
for _, tc := range delta.ToolCalls {
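|
||||
// Tool call arguments arrive in fragments; group them by index and concatenate the argument strings.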
|
||||
if tc.Index == nil {
|
||||
return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
|
||||
}
|
||||
if len(toolCalls) <= *tc.Index {
|
||||
toolCalls = append(toolCalls, tc)
|
||||
} else {
|
||||
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
|
||||
}
|
||||
}
|
||||
} else {
|
||||
output <- delta.Content
|
||||
content.WriteString(delta.Content)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
results, err := handleToolCalls(params, content.String(), toolCalls)
|
||||
if err != nil {
|
||||
return content.String(), err
|
||||
}
|
||||
if results != nil {
|
||||
*replies = append(*replies, results...)
|
||||
}
|
||||
|
||||
// Recurse into CreateChatCompletionStream with the tool call replies
|
||||
messages = append(messages, results...)
|
||||
return c.CreateChatCompletionStream(params, messages, replies, output)
|
||||
}
|
||||
|
||||
if replies != nil {
|
||||
*replies = append(*replies, model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: content.String(),
|
||||
})
|
||||
}
|
||||
|
||||
return content.String(), err
|
||||
}
|
@ -1,23 +0,0 @@
|
||||
package provider
|
||||
|
||||
import "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
|
||||
type ChatCompletionClient interface {
|
||||
// CreateChatCompletion requests a response to the provided messages.
|
||||
// Replies are appended to the given replies slice, and the
|
||||
// complete user-facing response is returned as a string.
|
||||
CreateChatCompletion(
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
replies *[]model.Message,
|
||||
) (string, error)
|
||||
|
||||
// Like CreateChatCompletion, except the response is streamed via
|
||||
// the output channel as it's received.
|
||||
CreateChatCompletionStream(
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
replies *[]model.Message,
|
||||
output chan<- string,
|
||||
) (string, error)
|
||||
}
|
@ -1,121 +0,0 @@
|
||||
package lmcli
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
sqids "github.com/sqids/sqids-go"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ConversationStore interface {
|
||||
Conversations() ([]model.Conversation, error)
|
||||
|
||||
ConversationByShortName(shortName string) (*model.Conversation, error)
|
||||
ConversationShortNameCompletions(search string) []string
|
||||
|
||||
SaveConversation(conversation *model.Conversation) error
|
||||
DeleteConversation(conversation *model.Conversation) error
|
||||
|
||||
Messages(conversation *model.Conversation) ([]model.Message, error)
|
||||
LastMessage(conversation *model.Conversation) (*model.Message, error)
|
||||
|
||||
SaveMessage(message *model.Message) error
|
||||
DeleteMessage(message *model.Message) error
|
||||
UpdateMessage(message *model.Message) error
|
||||
}
|
||||
|
||||
type SQLStore struct {
|
||||
db *gorm.DB
|
||||
sqids *sqids.Sqids
|
||||
}
|
||||
|
||||
func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
|
||||
models := []any{
|
||||
&model.Conversation{},
|
||||
&model.Message{},
|
||||
}
|
||||
|
||||
for _, x := range models {
|
||||
err := db.AutoMigrate(x)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not perform database migrations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
_sqids, _ := sqids.New(sqids.Options{MinLength: 4})
|
||||
return &SQLStore{db, _sqids}, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) SaveConversation(conversation *model.Conversation) error {
|
||||
err := s.db.Save(&conversation).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
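|
||||
// Assign a sqids-encoded short name derived from the conversation ID on first save.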
|
||||
|
||||
if !conversation.ShortName.Valid {
|
||||
shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)})
|
||||
conversation.ShortName = sql.NullString{String: shortName, Valid: true}
|
||||
err = s.db.Save(&conversation).Error
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLStore) DeleteConversation(conversation *model.Conversation) error {
|
||||
s.db.Where("conversation_id = ?", conversation.ID).Delete(&model.Message{})
|
||||
return s.db.Delete(&conversation).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) SaveMessage(message *model.Message) error {
|
||||
return s.db.Create(message).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) DeleteMessage(message *model.Message) error {
|
||||
return s.db.Delete(&message).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) UpdateMessage(message *model.Message) error {
|
||||
return s.db.Updates(&message).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) Conversations() ([]model.Conversation, error) {
|
||||
var conversations []model.Conversation
|
||||
err := s.db.Find(&conversations).Error
|
||||
return conversations, err
|
||||
}
|
||||
|
||||
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
|
||||
var completions []string
|
||||
conversations, _ := s.Conversations() // ignore error for completions
|
||||
for _, conversation := range conversations {
|
||||
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
|
||||
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
|
||||
}
|
||||
}
|
||||
return completions
|
||||
}
|
||||
|
||||
func (s *SQLStore) ConversationByShortName(shortName string) (*model.Conversation, error) {
|
||||
if shortName == "" {
|
||||
return nil, errors.New("shortName is empty")
|
||||
}
|
||||
var conversation model.Conversation
|
||||
err := s.db.Where("short_name = ?", shortName).Find(&conversation).Error
|
||||
return &conversation, err
|
||||
}
|
||||
|
||||
func (s *SQLStore) Messages(conversation *model.Conversation) ([]model.Message, error) {
|
||||
var messages []model.Message
|
||||
err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error
|
||||
return messages, err
|
||||
}
|
||||
|
||||
func (s *SQLStore) LastMessage(conversation *model.Conversation) (*model.Message, error) {
|
||||
var message model.Message
|
||||
err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error
|
||||
return &message, err
|
||||
}
|
@ -1,114 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
)
|
||||
|
||||
const FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path.
|
||||
|
||||
Make sure your inserts match the flow and indentation of surrounding content.`
|
||||
|
||||
var FileInsertLinesTool = model.Tool{
|
||||
Name: "file_insert_lines",
|
||||
Description: FILE_INSERT_LINES_DESCRIPTION,
|
||||
Parameters: []model.ToolParameter{
|
||||
{
|
||||
Name: "path",
|
||||
Type: "string",
|
||||
Description: "Path of the file to be modified, relative to the current working directory.",
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "position",
|
||||
Type: "integer",
|
||||
Description: `Which line to insert content *before*.`,
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "content",
|
||||
Type: "string",
|
||||
Description: `The content to insert.`,
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path parameter to write_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
var position int
|
||||
tmp, ok = args["position"]
|
||||
if ok {
|
||||
tmp, ok := tmp.(float64)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid position in function arguments: %v", tmp)
|
||||
}
|
||||
position = int(tmp)
|
||||
}
|
||||
var content string
|
||||
tmp, ok = args["content"]
|
||||
if ok {
|
||||
content, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
|
||||
result := fileInsertLines(path, position, content)
|
||||
ret, err := result.ToJson()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
|
||||
func fileInsertLines(path string, position int, content string) model.CallResult {
|
||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||
if !ok {
|
||||
return model.CallResult{Message: reason}
|
||||
}
|
||||
|
||||
// Read the existing file's content
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
||||
}
|
||||
_, err = os.Create(path)
|
||||
if err != nil {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
|
||||
}
|
||||
data = []byte{}
|
||||
}
|
||||
|
||||
if position < 1 {
|
||||
return model.CallResult{Message: "start_line cannot be less than 1"}
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
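|
||||
// Insert the new content before the 1-based position.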
|
||||
|
||||
before := lines[:position-1]
|
||||
after := lines[position-1:]
|
||||
lines = append(before, append(contentLines, after...)...)
|
||||
|
||||
newContent := strings.Join(lines, "\n")
|
||||
|
||||
// Write the updated content back to the file
|
||||
err = os.WriteFile(path, []byte(newContent), 0644)
|
||||
if err != nil {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
||||
}
|
||||
|
||||
return model.CallResult{Result: newContent}
|
||||
}
|
@ -1,133 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
||||
)
|
||||
|
||||
const FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path.
|
||||
|
||||
Useful for re-writing snippets/blocks of code or entire functions.
|
||||
|
||||
Plan your edits carefully and ensure any new content matches the flow and indentation of surrounding text.`
|
||||
|
||||
var FileReplaceLinesTool = model.Tool{
|
||||
Name: "file_replace_lines",
|
||||
Description: FILE_REPLACE_LINES_DESCRIPTION,
|
||||
Parameters: []model.ToolParameter{
|
||||
{
|
||||
Name: "path",
|
||||
Type: "string",
|
||||
Description: "Path of the file to be modified, relative to the current working directory.",
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "start_line",
|
||||
Type: "integer",
|
||||
Description: `Line number which specifies the start of the replacement range (inclusive).`,
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "end_line",
|
||||
Type: "integer",
|
||||
Description: `Line number which specifies the end of the replacement range (inclusive). If unset, range extends to end of file.`,
|
||||
},
|
||||
{
|
||||
Name: "content",
|
||||
Type: "string",
|
||||
Description: `Content to replace specified range. Omit to remove the specified range.`,
|
||||
},
|
||||
},
|
||||
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path parameter to write_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
var start_line int
|
||||
tmp, ok = args["start_line"]
|
||||
if ok {
|
||||
tmp, ok := tmp.(float64)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid start_line in function arguments: %v", tmp)
|
||||
}
|
||||
start_line = int(tmp)
|
||||
}
|
||||
var end_line int
|
||||
tmp, ok = args["end_line"]
|
||||
if ok {
|
||||
tmp, ok := tmp.(float64)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid end_line in function arguments: %v", tmp)
|
||||
}
|
||||
end_line = int(tmp)
|
||||
}
|
||||
var content string
|
||||
tmp, ok = args["content"]
|
||||
if ok {
|
||||
content, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
|
||||
result := fileReplaceLines(path, start_line, end_line, content)
|
||||
ret, err := result.ToJson()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
|
||||
func fileReplaceLines(path string, startLine int, endLine int, content string) model.CallResult {
|
||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||
if !ok {
|
||||
return model.CallResult{Message: reason}
|
||||
}
|
||||
|
||||
// Read the existing file's content
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
||||
}
|
||||
_, err = os.Create(path)
|
||||
if err != nil {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
|
||||
}
|
||||
data = []byte{}
|
||||
}
|
||||
|
||||
if startLine < 1 {
|
||||
return model.CallResult{Message: "start_line cannot be less than 1"}
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
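|
||||
// An unset (0) or out-of-range end_line means the range extends to the end of the file.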
|
||||
|
||||
if endLine == 0 || endLine > len(lines) {
|
||||
endLine = len(lines)
|
||||
}
|
||||
|
||||
before := lines[:startLine-1]
|
||||
after := lines[endLine:]
|
||||
|
||||
lines = append(before, append(contentLines, after...)...)
|
||||
newContent := strings.Join(lines, "\n")
|
||||
|
||||
// Write the updated content back to the file
|
||||
err = os.WriteFile(path, []byte(newContent), 0644)
|
||||
if err != nil {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
||||
}
|
||||
|
||||
return model.CallResult{Result: newContent}
|
||||
}
|
@ -1,100 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
||||
)
|
||||
|
||||
const READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory).
|
||||
|
||||
Example result:
|
||||
{
|
||||
"message": "success",
|
||||
"result": [
|
||||
{"name": "a_file.txt", "type": "file", "size": 123},
|
||||
{"name": "a_directory/", "type": "dir", "size": 11},
|
||||
...
|
||||
]
|
||||
}
|
||||
|
||||
For files, size represents the size of the file, in bytes.
|
||||
For directories, size represents the number of entries in that directory.`
|
||||
|
||||
var ReadDirTool = model.Tool{
|
||||
Name: "read_dir",
|
||||
Description: READ_DIR_DESCRIPTION,
|
||||
Parameters: []model.ToolParameter{
|
||||
{
|
||||
Name: "relative_dir",
|
||||
Type: "string",
|
||||
Description: "If set, read the contents of a directory relative to the current one.",
|
||||
},
|
||||
},
|
||||
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
||||
var relativeDir string
|
||||
tmp, ok := args["relative_dir"]
|
||||
if ok {
|
||||
relativeDir, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
result := readDir(relativeDir)
|
||||
ret, err := result.ToJson()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
|
||||
func readDir(path string) model.CallResult {
|
||||
if path == "" {
|
||||
path = "."
|
||||
}
|
||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||
if !ok {
|
||||
return model.CallResult{Message: reason}
|
||||
}
|
||||
|
||||
files, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return model.CallResult{
|
||||
Message: err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
var dirContents []map[string]interface{}
|
||||
for _, f := range files {
|
||||
info, _ := f.Info()
|
||||
|
||||
name := f.Name()
|
||||
if strings.HasPrefix(name, ".") {
|
||||
// skip hidden files
|
||||
continue
|
||||
}
|
||||
|
||||
entryType := "file"
|
||||
size := info.Size()
|
||||
|
||||
if info.IsDir() {
|
||||
name += "/"
|
||||
entryType = "dir"
|
||||
subdirfiles, _ := os.ReadDir(filepath.Join(".", path, info.Name()))
|
||||
size = int64(len(subdirfiles))
|
||||
}
|
||||
|
||||
dirContents = append(dirContents, map[string]interface{}{
|
||||
"name": name,
|
||||
"type": entryType,
|
||||
"size": size,
|
||||
})
|
||||
}
|
||||
|
||||
return model.CallResult{Result: dirContents}
|
||||
}
|
@ -1,71 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
||||
)
|
||||
|
||||
const READ_FILE_DESCRIPTION = `Read the contents of a text file relative to the current working directory.
|
||||
|
||||
Each line of the returned content is prefixed with its line number and a tab (\t).
|
||||
|
||||
Example result:
|
||||
{
|
||||
"message": "success",
|
||||
"result": "1\tthe contents\n2\tof the file\n"
|
||||
}`
|
||||
|
||||
var ReadFileTool = model.Tool{
|
||||
Name: "read_file",
|
||||
Description: READ_FILE_DESCRIPTION,
|
||||
Parameters: []model.ToolParameter{
|
||||
{
|
||||
Name: "path",
|
||||
Type: "string",
|
||||
Description: "Path to a file within the current working directory to read.",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
|
||||
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Path parameter to read_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
result := readFile(path)
|
||||
ret, err := result.ToJson()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
|
||||
func readFile(path string) model.CallResult {
|
||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||
if !ok {
|
||||
return model.CallResult{Message: reason}
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
content := strings.Builder{}
|
||||
for i, line := range lines {
|
||||
content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
|
||||
}
|
||||
|
||||
return model.CallResult{
|
||||
Result: content.String(),
|
||||
}
|
||||
}
|
@ -1,51 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
)
|
||||
|
||||
var AvailableTools map[string]model.Tool = map[string]model.Tool{
|
||||
"read_dir": ReadDirTool,
|
||||
"read_file": ReadFileTool,
|
||||
"write_file": WriteFileTool,
|
||||
"file_insert_lines": FileInsertLinesTool,
|
||||
"file_replace_lines": FileReplaceLinesTool,
|
||||
}
|
||||
|
||||
func ExecuteToolCalls(toolCalls []model.ToolCall, toolBag []model.Tool) ([]model.ToolResult, error) {
|
||||
var toolResults []model.ToolResult
|
||||
for _, toolCall := range toolCalls {
|
||||
var tool *model.Tool
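|
||||
// Only tools present in the provided toolBag may be executed.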
|
||||
for _, available := range toolBag {
|
||||
if available.Name == toolCall.Name {
|
||||
tool = &available
|
||||
break
|
||||
}
|
||||
}
|
||||
if tool == nil {
|
||||
return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Name)
|
||||
}
|
||||
|
||||
// TODO: ability to silence this
|
||||
fmt.Fprintf(os.Stderr, "\nINFO: Executing tool '%s' with args %s\n", toolCall.Name, toolCall.Parameters)
|
||||
|
||||
// Execute the tool
|
||||
result, err := tool.Impl(tool, toolCall.Parameters)
|
||||
if err != nil {
|
||||
// This can happen if the model omitted required args or supplied invalid ones
|
||||
return nil, fmt.Errorf("Tool '%s' error: %v\n", toolCall.Name, err)
|
||||
}
|
||||
|
||||
toolResult := model.ToolResult{
|
||||
ToolCallID: toolCall.ID,
|
||||
ToolName: toolCall.Name,
|
||||
Result: result,
|
||||
}
|
||||
|
||||
toolResults = append(toolResults, toolResult)
|
||||
}
|
||||
return toolResults, nil
|
||||
}
|
@ -1,67 +0,0 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// IsPathContained attempts to verify whether `path` is the same as or
|
||||
// contained within `directory`. It is overly cautious, returning false even if
|
||||
// `path` IS contained within `directory`, but the two paths use different
|
||||
// casing, and we happen to be on a case-insensitive filesystem.
|
||||
// This is ultimately to attempt to stop an LLM from going outside of where I
|
||||
// tell it to. Additional layers of security should be considered, e.g. run in a
|
||||
// VM/container.
|
||||
func IsPathContained(directory string, path string) (bool, error) {
|
||||
// Clean and resolve symlinks for both paths
|
||||
path, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// check if path exists
|
||||
_, err = os.Stat(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return false, fmt.Errorf("Could not stat path: %v", err)
|
||||
}
|
||||
} else {
|
||||
path, err = filepath.EvalSymlinks(path)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
directory, err = filepath.Abs(directory)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
directory, err = filepath.EvalSymlinks(directory)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Case insensitive checks
|
||||
if !strings.EqualFold(path, directory) &&
|
||||
!strings.HasPrefix(strings.ToLower(path), strings.ToLower(directory)+string(os.PathSeparator)) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func IsPathWithinCWD(path string) (bool, string) {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return false, "Failed to determine current working directory"
|
||||
}
|
||||
if ok, err := IsPathContained(cwd, path); !ok {
|
||||
if err != nil {
|
||||
return false, fmt.Sprintf("Could not determine whether path '%s' is within the current working directory: %s", path, err.Error())
|
||||
}
|
||||
return false, fmt.Sprintf("Path '%s' is not within the current working directory", path)
|
||||
}
|
||||
return true, ""
|
||||
}
|
@ -1,71 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
||||
)
|
||||
|
||||
const WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory.
|
||||
|
||||
Example result:
|
||||
{
|
||||
"message": "success"
|
||||
}`
|
||||
|
||||
var WriteFileTool = model.Tool{
|
||||
Name: "write_file",
|
||||
Description: WRITE_FILE_DESCRIPTION,
|
||||
Parameters: []model.ToolParameter{
|
||||
{
|
||||
Name: "path",
|
||||
Type: "string",
|
||||
Description: "Path to a file within the current working directory to write to.",
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "content",
|
||||
Type: "string",
|
||||
Description: "The content to write to the file. Overwrites any existing content!",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
Impl: func(t *model.Tool, args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Path parameter to write_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
tmp, ok = args["content"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Content parameter to write_file was not included.")
|
||||
}
|
||||
content, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||
}
|
||||
result := writeFile(path, content)
|
||||
ret, err := result.ToJson()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
|
||||
func writeFile(path string, content string) model.CallResult {
|
||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||
if !ok {
|
||||
return model.CallResult{Message: reason}
|
||||
}
|
||||
err := os.WriteFile(path, []byte(content), 0644)
|
||||
if err != nil {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
||||
}
|
||||
return model.CallResult{}
|
||||
}
|
@ -1,60 +0,0 @@
|
||||
package tty
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/alecthomas/chroma/v2"
|
||||
"github.com/alecthomas/chroma/v2/formatters"
|
||||
"github.com/alecthomas/chroma/v2/lexers"
|
||||
"github.com/alecthomas/chroma/v2/styles"
|
||||
)
|
||||
|
||||
type ChromaHighlighter struct {
|
||||
lexer chroma.Lexer
|
||||
formatter chroma.Formatter
|
||||
style *chroma.Style
|
||||
}
|
||||
|
||||
func NewChromaHighlighter(lang, format, style string) *ChromaHighlighter {
|
||||
l := lexers.Get(lang)
|
||||
if l == nil {
|
||||
l = lexers.Fallback
|
||||
}
|
||||
l = chroma.Coalesce(l)
|
||||
|
||||
f := formatters.Get(format)
|
||||
if f == nil {
|
||||
f = formatters.Fallback
|
||||
}
|
||||
|
||||
s := styles.Get(style)
|
||||
if s == nil {
|
||||
s = styles.Fallback
|
||||
}
|
||||
|
||||
return &ChromaHighlighter{
|
||||
lexer: l,
|
||||
formatter: f,
|
||||
style: s,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ChromaHighlighter) Highlight(w io.Writer, text string) error {
|
||||
it, err := s.lexer.Tokenise(nil, text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.formatter.Format(w, s.style, it)
|
||||
}
|
||||
|
||||
func (s *ChromaHighlighter) HighlightS(text string) (string, error) {
|
||||
it, err := s.lexer.Tokenise(nil, text)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb := strings.Builder{}
|
||||
sb.Grow(len(text) * 2)
|
||||
err = s.formatter.Format(&sb, s.style, it)
|
||||
return sb.String(), err
|
||||
}
|