2023-11-04 12:20:13 -06:00
|
|
|
package cli
|
2023-11-03 10:56:20 -06:00
|
|
|
|
|
|
|
import (
|
|
|
|
"database/sql"
|
2023-11-03 20:01:15 -06:00
|
|
|
"os"
|
|
|
|
"path/filepath"
|
2023-11-03 10:56:20 -06:00
|
|
|
"gorm.io/gorm"
|
|
|
|
"gorm.io/driver/sqlite"
|
|
|
|
sqids "github.com/sqids/sqids-go"
|
|
|
|
)
|
|
|
|
|
|
|
|
type Store struct {
|
|
|
|
db *gorm.DB
|
|
|
|
sqids *sqids.Sqids
|
|
|
|
}
|
|
|
|
|
|
|
|
type Message struct {
|
|
|
|
ID uint `gorm:"primaryKey"`
|
|
|
|
ConversationID uint `gorm:"foreignKey:ConversationID"`
|
|
|
|
Conversation Conversation
|
|
|
|
OriginalContent string
|
|
|
|
Role string // 'user' or 'assistant'
|
|
|
|
}
|
|
|
|
|
|
|
|
// Conversation is a persisted chat conversation; Message rows point at it
// through their ConversationID.
type Conversation struct {
	// ID is the primary key.
	ID uint `gorm:"primaryKey"`
	// ShortName is a compact identifier that Store.SaveConversation derives
	// from ID; it stays NULL (Valid == false) until that happens.
	ShortName sql.NullString
	// Title is the conversation's title.
	Title string
}
|
|
|
|
|
2023-11-03 20:01:15 -06:00
|
|
|
|
|
|
|
// getDataDir returns the directory where lmcli keeps its data files,
// attempting to create it (mode 0755) if it does not exist.
//
// $XDG_DATA_HOME/lmcli is used when XDG_DATA_HOME holds an absolute path.
// The XDG Base Directory specification declares relative values invalid, so
// an empty or relative XDG_DATA_HOME is ignored and ~/.local/share/lmcli is
// used instead.
//
// Errors are not reported (the signature has no error result): if the home
// directory cannot be determined the returned path is relative to the
// working directory, and a failure to create the directory will surface
// when the caller tries to open a file inside it.
func getDataDir() string {
	var dataDir string
	if xdgDataHome := os.Getenv("XDG_DATA_HOME"); filepath.IsAbs(xdgDataHome) {
		dataDir = filepath.Join(xdgDataHome, "lmcli")
	} else {
		// Best effort: on error userHomeDir is "" (see doc comment).
		userHomeDir, _ := os.UserHomeDir()
		dataDir = filepath.Join(userHomeDir, ".local", "share", "lmcli")
	}

	// Best effort: a failure here is reported later by whoever opens files
	// inside dataDir.
	_ = os.MkdirAll(dataDir, 0755)
	return dataDir
}
|
2023-11-03 10:56:20 -06:00
|
|
|
|
|
|
|
func InitializeStore() (*Store, error) {
|
2023-11-03 20:01:15 -06:00
|
|
|
databaseFile := filepath.Join(getDataDir(), "conversations.db")
|
|
|
|
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{})
|
2023-11-03 10:56:20 -06:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
models := []any{
|
|
|
|
&Conversation{},
|
|
|
|
&Message{},
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, x := range(models) {
|
|
|
|
err := db.AutoMigrate(x)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
_sqids, _ := sqids.New(sqids.Options{
|
|
|
|
MinLength: 4,
|
|
|
|
})
|
|
|
|
|
|
|
|
return &Store{db: db, sqids: _sqids}, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Store) SaveConversation(conversation *Conversation) error {
|
|
|
|
err := s.db.Save(&conversation).Error
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
if !conversation.ShortName.Valid {
|
|
|
|
shortName, _ := s.sqids.Encode([]uint64{ uint64(conversation.ID) })
|
|
|
|
conversation.ShortName = sql.NullString{String: shortName, Valid: true}
|
|
|
|
err = s.db.Save(&conversation).Error
|
|
|
|
}
|
|
|
|
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Store) SaveMessage(message *Message) error {
|
|
|
|
return s.db.Create(message).Error
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Store) GetConversations() ([]Conversation, error) {
|
|
|
|
var conversations []Conversation
|
|
|
|
err := s.db.Find(&conversations).Error
|
|
|
|
return conversations, err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Store) GetMessages(conversation *Conversation) ([]Message, error) {
|
|
|
|
var messages []Message
|
|
|
|
err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error
|
|
|
|
return messages, err
|
|
|
|
}
|