From 8fe2a2cf5350a99fa7a3c3c8cea35823e13fadc0 Mon Sep 17 00:00:00 2001 From: Matt Low Date: Fri, 3 Nov 2023 16:56:20 +0000 Subject: [PATCH] Add initial store.go for conversation/message persistence --- go.mod | 6 ++++ go.sum | 12 ++++++++ main.go | 22 +++++--------- store.go | 87 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 112 insertions(+), 15 deletions(-) create mode 100644 store.go diff --git a/go.mod b/go.mod index 8a6a052..ff03af0 100644 --- a/go.mod +++ b/go.mod @@ -5,9 +5,15 @@ go 1.19 require ( github.com/sashabaranov/go-openai v1.16.0 github.com/spf13/cobra v1.7.0 + github.com/sqids/sqids-go v0.4.1 + gorm.io/driver/sqlite v1.5.4 + gorm.io/gorm v1.25.5 ) require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/mattn/go-sqlite3 v1.14.17 // indirect github.com/spf13/pflag v1.0.5 // indirect ) diff --git a/go.sum b/go.sum index 3d2c5dd..687200b 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,12 @@ github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= +github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sashabaranov/go-openai v1.16.0 h1:34W6WV84ey6OpW0p2UewZkdMu82AxGC+BzpU6iiauRw= github.com/sashabaranov/go-openai v1.16.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= @@ -8,5 +14,11 @@ github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw= +github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/sqlite v1.5.4 h1:IqXwXi8M/ZlPzH/947tn5uik3aYQslP9BVveoax0nV0= +gorm.io/driver/sqlite v1.5.4/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4= +gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= +gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= diff --git a/main.go b/main.go index bd5d550..cf43819 100644 --- a/main.go +++ b/main.go @@ -7,21 +7,13 @@ import ( "github.com/spf13/cobra" ) -type Message struct { - MessageID string - ConversationID string - Conversation Conversation - OriginalContent string - Role string // 'user' or 'assistant' -} +var store, storeError = InitializeStore() -type Conversation struct { - ID string - Title string -} - -type Context struct { - CurrentConversation string +func checkStore() { + if storeError != nil { + fmt.Fprintf(os.Stderr, "Error establishing connection to store: %v\n", storeError) + os.Exit(1) + } } var rootCmd = &cobra.Command{ @@ -29,6 +21,7 @@ var rootCmd = &cobra.Command{ Short: "Interact with Large Language Models", Long: `lm is a CLI tool to interact with OpenAI's GPT 3.5 and GPT 4.`, Run: func(cmd *cobra.Command, args []string) { + checkStore() // execute `lm ls` by default }, } @@ -100,7 +93,6 @@ var newCmd = &cobra.Command{ fmt.Printf("> %s\n", messageContents) - // Initialize the messages array for this conversation. messages := []Message{ { OriginalContent: messageContents, diff --git a/store.go b/store.go new file mode 100644 index 0000000..2712edb --- /dev/null +++ b/store.go @@ -0,0 +1,87 @@ +package main + +import ( + "database/sql" + "gorm.io/gorm" + "gorm.io/driver/sqlite" + sqids "github.com/sqids/sqids-go" +) + +type Store struct { + db *gorm.DB + sqids *sqids.Sqids +} + +type Message struct { + ID uint `gorm:"primaryKey"` + ConversationID uint `gorm:"foreignKey:ConversationID"` + Conversation Conversation + OriginalContent string + Role string // 'user' or 'assistant' +} + +type Conversation struct { + ID uint `gorm:"primaryKey"` + ShortName sql.NullString + Title string +} + +const ( + DATABASE_FILE string = "./data.db" +) + +func InitializeStore() (*Store, error) { + db, err := gorm.Open(sqlite.Open(DATABASE_FILE), &gorm.Config{}) + if err != nil { + return nil, err + } + + models := []any{ + &Conversation{}, + &Message{}, + } + + for _, x := range(models) { + err := db.AutoMigrate(x) + if err != nil { + return nil, err + } + } + + _sqids, _ := sqids.New(sqids.Options{ + MinLength: 4, + }) + + return &Store{db: db, sqids: _sqids}, nil +} + +func (s *Store) SaveConversation(conversation *Conversation) error { + err := s.db.Save(&conversation).Error + if err != nil { + return err + } + + if !conversation.ShortName.Valid { + shortName, _ := s.sqids.Encode([]uint64{ uint64(conversation.ID) }) + conversation.ShortName = sql.NullString{String: shortName, Valid: true} + err = s.db.Save(&conversation).Error + } + + return err +} + +func (s *Store) SaveMessage(message *Message) error { + return s.db.Create(message).Error +} + +func (s *Store) GetConversations() ([]Conversation, error) { + var conversations []Conversation + err := s.db.Find(&conversations).Error + return conversations, err +} + +func (s *Store) GetMessages(conversation *Conversation) ([]Message, error) { + var messages []Message + err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error + return messages, err +}