99 lines
2.2 KiB
Go
99 lines
2.2 KiB
Go
|
package conversation
|
||
|
|
||
|
import (
|
||
|
"database/sql"
|
||
|
"database/sql/driver"
|
||
|
"encoding/json"
|
||
|
"fmt"
|
||
|
"time"
|
||
|
|
||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||
|
)
|
||
|
|
||
|
type Conversation struct {
|
||
|
ID uint `gorm:"primaryKey"`
|
||
|
ShortName sql.NullString
|
||
|
Title string
|
||
|
SelectedRootID *uint
|
||
|
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
|
||
|
RootMessages []Message `gorm:"-:all"`
|
||
|
}
|
||
|
|
||
|
// MessageMeta carries optional generation metadata for a Message. It is
// persisted as JSON through its Scan/Value methods; nil fields are left
// out of the encoded object because of the `omitempty` tags.
type MessageMeta struct {
	// GenerationProvider presumably names the provider that produced the
	// message; nil when unknown.
	GenerationProvider *string `json:"generation_provider,omitempty"`
	// GenerationModel presumably names the model that produced the
	// message; nil when unknown.
	GenerationModel *string `json:"generation_model,omitempty"`
}
|
||
|
|
||
|
type Message struct {
|
||
|
ID uint `gorm:"primaryKey"`
|
||
|
CreatedAt time.Time
|
||
|
Metadata MessageMeta
|
||
|
|
||
|
ConversationID *uint `gorm:"index"`
|
||
|
Conversation *Conversation `gorm:"foreignKey:ConversationID"`
|
||
|
ParentID *uint
|
||
|
Parent *Message `gorm:"foreignKey:ParentID"`
|
||
|
Replies []Message `gorm:"foreignKey:ParentID"`
|
||
|
SelectedReplyID *uint
|
||
|
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
|
||
|
|
||
|
Role api.MessageRole
|
||
|
Content string
|
||
|
ToolCalls ToolCalls // a json array of tool calls (from the model)
|
||
|
ToolResults ToolResults // a json array of tool results
|
||
|
}
|
||
|
|
||
|
func (m *MessageMeta) Scan(value interface{}) error {
|
||
|
return json.Unmarshal(value.([]byte), m)
|
||
|
}
|
||
|
|
||
|
func (m MessageMeta) Value() (driver.Value, error) {
|
||
|
return json.Marshal(m)
|
||
|
}
|
||
|
|
||
|
// ToolCalls is the list of tool calls produced by the model for a message.
// It is persisted as a JSON array string through its Scan and Value methods.
type ToolCalls []api.ToolCall
|
||
|
|
||
|
func (tc *ToolCalls) Scan(value any) (err error) {
|
||
|
s := value.(string)
|
||
|
if value == nil || s == "" {
|
||
|
*tc = nil
|
||
|
return
|
||
|
}
|
||
|
err = json.Unmarshal([]byte(s), tc)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (tc ToolCalls) Value() (driver.Value, error) {
|
||
|
if len(tc) == 0 {
|
||
|
return "", nil
|
||
|
}
|
||
|
jsonBytes, err := json.Marshal(tc)
|
||
|
if err != nil {
|
||
|
return "", fmt.Errorf("Could not marshal ToolCalls to JSON: %v\n", err)
|
||
|
}
|
||
|
return string(jsonBytes), nil
|
||
|
}
|
||
|
|
||
|
// ToolResults is the list of tool results attached to a message. It is
// persisted as a JSON array string through its Scan and Value methods.
type ToolResults []api.ToolResult
|
||
|
|
||
|
func (tr *ToolResults) Scan(value any) (err error) {
|
||
|
s := value.(string)
|
||
|
if value == nil || s == "" {
|
||
|
*tr = nil
|
||
|
return
|
||
|
}
|
||
|
err = json.Unmarshal([]byte(s), tr)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (tr ToolResults) Value() (driver.Value, error) {
|
||
|
if len(tr) == 0 {
|
||
|
return "", nil
|
||
|
}
|
||
|
jsonBytes, err := json.Marshal([]api.ToolResult(tr))
|
||
|
if err != nil {
|
||
|
return "", fmt.Errorf("Could not marshal ToolResults to JSON: %v\n", err)
|
||
|
}
|
||
|
return string(jsonBytes), nil
|
||
|
}
|