lmcli/pkg/cli/store.go

101 lines
2.2 KiB
Go
Raw Normal View History

package cli
import (
"database/sql"
"fmt"
"os"
"path/filepath"
2023-11-04 16:56:22 -06:00
sqids "github.com/sqids/sqids-go"
2023-11-04 16:56:22 -06:00
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
type Store struct {
2023-11-04 16:56:22 -06:00
db *gorm.DB
sqids *sqids.Sqids
}
type Message struct {
2023-11-04 16:56:22 -06:00
ID uint `gorm:"primaryKey"`
ConversationID uint `gorm:"foreignKey:ConversationID"`
Conversation Conversation
OriginalContent string
2023-11-04 16:56:22 -06:00
Role string // 'user' or 'assistant'
}
// Conversation is a chat session; Messages refer to it via ConversationID.
type Conversation struct {
	// ID is the primary key.
	ID uint `gorm:"primaryKey"`
	// ShortName is a compact identifier derived from ID. While it is NULL
	// (Valid == false), no short name has been assigned yet.
	ShortName sql.NullString
	// Title is the conversation's title.
	Title string
}
// getDataDir returns the directory where lmcli keeps its persistent data,
// creating it with mode 0755 if it does not already exist.
//
// The location follows the XDG Base Directory convention:
// $XDG_DATA_HOME/lmcli when XDG_DATA_HOME holds an absolute path, otherwise
// <home>/.local/share/lmcli, where <home> comes from os.UserHomeDir. The XDG
// spec requires relative values of XDG_DATA_HOME to be treated as invalid and
// ignored. Honouring them would scatter data directories relative to whatever
// the current working directory happens to be.
func getDataDir() string {
	dataHome := os.Getenv("XDG_DATA_HOME")
	// filepath.IsAbs("") is false, so this also covers the unset/empty case.
	if !filepath.IsAbs(dataHome) {
		// If the home directory cannot be determined, homeDir is "" and the
		// result is relative to the working directory. That matches the
		// previous behaviour; this function has no error return to report it.
		homeDir, _ := os.UserHomeDir()
		dataHome = filepath.Join(homeDir, ".local", "share")
	}
	dataDir := filepath.Join(dataHome, "lmcli")

	// Best effort: with no error return available, a creation failure is left
	// to surface when the caller tries to use the directory (e.g. opening the
	// database file inside it).
	_ = os.MkdirAll(dataDir, 0o755)
	return dataDir
}
func NewStore() (*Store, error) {
databaseFile := filepath.Join(getDataDir(), "conversations.db")
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{})
if err != nil {
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
}
models := []any{
&Conversation{},
&Message{},
}
2023-11-04 16:56:22 -06:00
for _, x := range models {
err := db.AutoMigrate(x)
if err != nil {
return nil, fmt.Errorf("Could not perform database migrations: %v", err)
}
}
_sqids, _ := sqids.New(sqids.Options{MinLength: 4})
return &Store{db, _sqids}, nil
}
func (s *Store) SaveConversation(conversation *Conversation) error {
err := s.db.Save(&conversation).Error
if err != nil {
return err
}
if !conversation.ShortName.Valid {
2023-11-04 16:56:22 -06:00
shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)})
conversation.ShortName = sql.NullString{String: shortName, Valid: true}
err = s.db.Save(&conversation).Error
}
return err
}
// SaveMessage inserts message as a new row and returns any database error.
func (s *Store) SaveMessage(message *Message) error {
	result := s.db.Create(message)
	return result.Error
}
// GetConversations loads every stored conversation. The slice holds whatever
// was loaded alongside any error the query produced.
func (s *Store) GetConversations() ([]Conversation, error) {
	var all []Conversation
	result := s.db.Find(&all)
	return all, result.Error
}
// GetMessages returns every message belonging to conversation, sorted by
// ascending message ID.
//
// The explicit ORDER BY matters: SQL gives no ordering guarantee without one,
// and a chat history replayed out of order is wrong. IDs are assigned on
// insert (SaveMessage creates rows without an ID), so ascending ID is
// presumably insertion order — verify if IDs are ever set explicitly.
func (s *Store) GetMessages(conversation *Conversation) ([]Message, error) {
	var messages []Message
	err := s.db.
		Where("conversation_id = ?", conversation.ID).
		Order("id ASC").
		Find(&messages).Error
	return messages, err
}