2024-10-19 20:38:42 -06:00
|
|
|
package conversation
|
|
|
|
|
|
|
|
import (
|
|
|
|
"database/sql"
|
|
|
|
"database/sql/driver"
|
|
|
|
"encoding/json"
|
|
|
|
"fmt"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
|
|
|
)
|
|
|
|
|
|
|
|
type Conversation struct {
|
|
|
|
ID uint `gorm:"primaryKey"`
|
|
|
|
ShortName sql.NullString
|
|
|
|
Title string
|
|
|
|
SelectedRootID *uint
|
|
|
|
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
|
|
|
|
RootMessages []Message `gorm:"-:all"`
|
2024-10-21 09:33:20 -06:00
|
|
|
LastMessageAt time.Time
|
2024-10-19 20:38:42 -06:00
|
|
|
}
|
|
|
|
|
|
|
|
// MessageMeta holds optional generation details (provider and model names)
// attached to a Message. It is persisted as one JSON-encoded column via its
// Scan and Value methods; nil fields are omitted from the encoded JSON.
type MessageMeta struct {
	GenerationProvider *string `json:"generation_provider,omitempty"`
	GenerationModel    *string `json:"generation_model,omitempty"`
}
|
|
|
|
|
|
|
|
type Message struct {
|
|
|
|
ID uint `gorm:"primaryKey"`
|
|
|
|
CreatedAt time.Time
|
|
|
|
Metadata MessageMeta
|
|
|
|
|
|
|
|
ConversationID *uint `gorm:"index"`
|
|
|
|
Conversation *Conversation `gorm:"foreignKey:ConversationID"`
|
|
|
|
ParentID *uint
|
|
|
|
Parent *Message `gorm:"foreignKey:ParentID"`
|
|
|
|
Replies []Message `gorm:"foreignKey:ParentID"`
|
|
|
|
SelectedReplyID *uint
|
|
|
|
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
|
|
|
|
|
|
|
|
Role api.MessageRole
|
|
|
|
Content string
|
|
|
|
ToolCalls ToolCalls // a json array of tool calls (from the model)
|
|
|
|
ToolResults ToolResults // a json array of tool results
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *MessageMeta) Scan(value interface{}) error {
|
|
|
|
return json.Unmarshal(value.([]byte), m)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Value implements driver.Valuer, encoding the metadata as a JSON byte slice
// for storage in a single column.
func (m MessageMeta) Value() (driver.Value, error) {
	encoded, err := json.Marshal(m)
	return encoded, err
}
|
|
|
|
|
|
|
|
// ToolCalls is the list of tool calls attached to a Message. It is stored in
// a single column as a JSON array string via its Scan and Value methods.
type ToolCalls []api.ToolCall
|
|
|
|
|
|
|
|
func (tc *ToolCalls) Scan(value any) (err error) {
|
|
|
|
s := value.(string)
|
|
|
|
if value == nil || s == "" {
|
|
|
|
*tc = nil
|
|
|
|
return
|
|
|
|
}
|
|
|
|
err = json.Unmarshal([]byte(s), tc)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// Value implements driver.Valuer, encoding the tool calls as a JSON array
// string. A nil or empty slice is stored as "" (not "null" or "[]"), which
// Scan decodes back to nil.
//
// Marshal failures are wrapped with %w so callers can inspect the cause.
func (tc ToolCalls) Value() (driver.Value, error) {
	if len(tc) == 0 {
		return "", nil
	}
	encoded, err := json.Marshal(tc)
	if err != nil {
		return nil, fmt.Errorf("marshaling ToolCalls to JSON: %w", err)
	}
	return string(encoded), nil
}
|
|
|
|
|
|
|
|
// ToolResults is the list of tool results attached to a Message. It is stored
// in a single column as a JSON array string via its Scan and Value methods.
type ToolResults []api.ToolResult
|
|
|
|
|
|
|
|
func (tr *ToolResults) Scan(value any) (err error) {
|
|
|
|
s := value.(string)
|
|
|
|
if value == nil || s == "" {
|
|
|
|
*tr = nil
|
|
|
|
return
|
|
|
|
}
|
|
|
|
err = json.Unmarshal([]byte(s), tr)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// Value implements driver.Valuer, encoding the tool results as a JSON array
// string. A nil or empty slice is stored as "" (not "null" or "[]"), which
// Scan decodes back to nil.
//
// Marshal failures are wrapped with %w so callers can inspect the cause.
func (tr ToolResults) Value() (driver.Value, error) {
	if len(tr) == 0 {
		return "", nil
	}
	// Marshal as the underlying []api.ToolResult, as the original code did.
	// NOTE(review): presumably this bypasses any JSON methods on ToolResults
	// itself; verify before removing the conversion.
	encoded, err := json.Marshal([]api.ToolResult(tr))
	if err != nil {
		return nil, fmt.Errorf("marshaling ToolResults to JSON: %w", err)
	}
	return string(encoded), nil
}
|