Large refactor - it compiles!
This refactor splits out all conversation concerns into a new `conversation` package. There is now a split between `conversation` and `api`s representation of `Message`, the latter storing the minimum information required for interaction with LLM providers. There is necessary conversation between the two when making LLM calls.
This commit is contained in:
parent
2ea8a73eb5
commit
0384c7cb66
118
pkg/api/api.go
Normal file
118
pkg/api/api.go
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MessageRole string
|
||||||
|
|
||||||
|
const (
|
||||||
|
MessageRoleSystem MessageRole = "system"
|
||||||
|
MessageRoleUser MessageRole = "user"
|
||||||
|
MessageRoleAssistant MessageRole = "assistant"
|
||||||
|
MessageRoleToolCall MessageRole = "tool_call"
|
||||||
|
MessageRoleToolResult MessageRole = "tool_result"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
Content string // TODO: support multi-part messages
|
||||||
|
Role MessageRole
|
||||||
|
ToolCalls []ToolCall
|
||||||
|
ToolResults []ToolResult
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolSpec struct {
|
||||||
|
Name string
|
||||||
|
Description string
|
||||||
|
Parameters []ToolParameter
|
||||||
|
Impl func(*ToolSpec, map[string]interface{}) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolParameter struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Type string `json:"type"` // "string", "integer", "boolean"
|
||||||
|
Required bool `json:"required"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Enum []string `json:"enum,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolCall struct {
|
||||||
|
ID string `json:"id" yaml:"-"`
|
||||||
|
Name string `json:"name" yaml:"tool"`
|
||||||
|
Parameters map[string]interface{} `json:"parameters" yaml:"parameters"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolResult struct {
|
||||||
|
ToolCallID string `json:"toolCallID" yaml:"-"`
|
||||||
|
ToolName string `json:"toolName,omitempty" yaml:"tool"`
|
||||||
|
Result string `json:"result,omitempty" yaml:"result"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMessageWithAssistant(content string) *Message {
|
||||||
|
return &Message{
|
||||||
|
Role: MessageRoleAssistant,
|
||||||
|
Content: content,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMessageWithToolCalls(content string, toolCalls []ToolCall) *Message {
|
||||||
|
return &Message{
|
||||||
|
Role: MessageRoleToolCall,
|
||||||
|
Content: content,
|
||||||
|
ToolCalls: toolCalls,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MessageRole) IsAssistant() bool {
|
||||||
|
switch m {
|
||||||
|
case MessageRoleAssistant, MessageRoleToolCall:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MessageRole) IsSystem() bool {
|
||||||
|
switch m {
|
||||||
|
case MessageRoleSystem:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// FriendlyRole returns a human friendly signifier for the message's role.
|
||||||
|
func (m MessageRole) FriendlyRole() string {
|
||||||
|
switch m {
|
||||||
|
case MessageRoleUser:
|
||||||
|
return "You"
|
||||||
|
case MessageRoleSystem:
|
||||||
|
return "System"
|
||||||
|
case MessageRoleAssistant:
|
||||||
|
return "Assistant"
|
||||||
|
case MessageRoleToolCall:
|
||||||
|
return "Tool Call"
|
||||||
|
case MessageRoleToolResult:
|
||||||
|
return "Tool Result"
|
||||||
|
default:
|
||||||
|
return string(m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: remove this
|
||||||
|
type CallResult struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Result any `json:"result,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r CallResult) ToJson() (string, error) {
|
||||||
|
if r.Message == "" {
|
||||||
|
// When message not supplied, assume success
|
||||||
|
r.Message = "success"
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonBytes, err := json.Marshal(r)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("Could not marshal CallResult to JSON: %v\n", err)
|
||||||
|
}
|
||||||
|
return string(jsonBytes), nil
|
||||||
|
}
|
@ -1,106 +0,0 @@
|
|||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"database/sql/driver"
|
|
||||||
"encoding/json"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Conversation struct {
|
|
||||||
ID uint `gorm:"primaryKey"`
|
|
||||||
ShortName sql.NullString
|
|
||||||
Title string
|
|
||||||
SelectedRootID *uint
|
|
||||||
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type MessageRole string
|
|
||||||
|
|
||||||
const (
|
|
||||||
MessageRoleSystem MessageRole = "system"
|
|
||||||
MessageRoleUser MessageRole = "user"
|
|
||||||
MessageRoleAssistant MessageRole = "assistant"
|
|
||||||
MessageRoleToolCall MessageRole = "tool_call"
|
|
||||||
MessageRoleToolResult MessageRole = "tool_result"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MessageMeta struct {
|
|
||||||
GenerationProvider *string `json:"generation_provider,omitempty"`
|
|
||||||
GenerationModel *string `json:"generation_model,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Message struct {
|
|
||||||
ID uint `gorm:"primaryKey"`
|
|
||||||
CreatedAt time.Time
|
|
||||||
Metadata MessageMeta
|
|
||||||
|
|
||||||
ConversationID *uint `gorm:"index"`
|
|
||||||
Conversation *Conversation `gorm:"foreignKey:ConversationID"`
|
|
||||||
ParentID *uint
|
|
||||||
Parent *Message `gorm:"foreignKey:ParentID"`
|
|
||||||
Replies []Message `gorm:"foreignKey:ParentID"`
|
|
||||||
SelectedReplyID *uint
|
|
||||||
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
|
|
||||||
|
|
||||||
Role MessageRole
|
|
||||||
Content string
|
|
||||||
ToolCalls ToolCalls // a json array of tool calls (from the model)
|
|
||||||
ToolResults ToolResults // a json array of tool results
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MessageMeta) Scan(value interface{}) error {
|
|
||||||
return json.Unmarshal(value.([]byte), m)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m MessageMeta) Value() (driver.Value, error) {
|
|
||||||
return json.Marshal(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ApplySystemPrompt(m []Message, system string, force bool) []Message {
|
|
||||||
if len(m) > 0 && m[0].Role == MessageRoleSystem {
|
|
||||||
if force {
|
|
||||||
m[0].Content = system
|
|
||||||
}
|
|
||||||
return m
|
|
||||||
} else {
|
|
||||||
return append([]Message{{
|
|
||||||
Role: MessageRoleSystem,
|
|
||||||
Content: system,
|
|
||||||
}}, m...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m MessageRole) IsAssistant() bool {
|
|
||||||
switch m {
|
|
||||||
case MessageRoleAssistant, MessageRoleToolCall:
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m MessageRole) IsSystem() bool {
|
|
||||||
switch m {
|
|
||||||
case MessageRoleSystem:
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// FriendlyRole returns a human friendly signifier for the message's role.
|
|
||||||
func (m MessageRole) FriendlyRole() string {
|
|
||||||
switch m {
|
|
||||||
case MessageRoleUser:
|
|
||||||
return "You"
|
|
||||||
case MessageRoleSystem:
|
|
||||||
return "System"
|
|
||||||
case MessageRoleAssistant:
|
|
||||||
return "Assistant"
|
|
||||||
case MessageRoleToolCall:
|
|
||||||
return "Tool Call"
|
|
||||||
case MessageRoleToolResult:
|
|
||||||
return "Tool Result"
|
|
||||||
default:
|
|
||||||
return string(m)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,98 +0,0 @@
|
|||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql/driver"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ToolSpec struct {
|
|
||||||
Name string
|
|
||||||
Description string
|
|
||||||
Parameters []ToolParameter
|
|
||||||
Impl func(*ToolSpec, map[string]interface{}) (string, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolParameter struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Type string `json:"type"` // "string", "integer", "boolean"
|
|
||||||
Required bool `json:"required"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Enum []string `json:"enum,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolCall struct {
|
|
||||||
ID string `json:"id" yaml:"-"`
|
|
||||||
Name string `json:"name" yaml:"tool"`
|
|
||||||
Parameters map[string]interface{} `json:"parameters" yaml:"parameters"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolResult struct {
|
|
||||||
ToolCallID string `json:"toolCallID" yaml:"-"`
|
|
||||||
ToolName string `json:"toolName,omitempty" yaml:"tool"`
|
|
||||||
Result string `json:"result,omitempty" yaml:"result"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolCalls []ToolCall
|
|
||||||
|
|
||||||
func (tc *ToolCalls) Scan(value any) (err error) {
|
|
||||||
s := value.(string)
|
|
||||||
if value == nil || s == "" {
|
|
||||||
*tc = nil
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = json.Unmarshal([]byte(s), tc)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tc ToolCalls) Value() (driver.Value, error) {
|
|
||||||
if len(tc) == 0 {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
jsonBytes, err := json.Marshal(tc)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("Could not marshal ToolCalls to JSON: %v\n", err)
|
|
||||||
}
|
|
||||||
return string(jsonBytes), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolResults []ToolResult
|
|
||||||
|
|
||||||
func (tr *ToolResults) Scan(value any) (err error) {
|
|
||||||
s := value.(string)
|
|
||||||
if value == nil || s == "" {
|
|
||||||
*tr = nil
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = json.Unmarshal([]byte(s), tr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tr ToolResults) Value() (driver.Value, error) {
|
|
||||||
if len(tr) == 0 {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
jsonBytes, err := json.Marshal([]ToolResult(tr))
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("Could not marshal ToolResults to JSON: %v\n", err)
|
|
||||||
}
|
|
||||||
return string(jsonBytes), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type CallResult struct {
|
|
||||||
Message string `json:"message"`
|
|
||||||
Result any `json:"result,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r CallResult) ToJson() (string, error) {
|
|
||||||
if r.Message == "" {
|
|
||||||
// When message not supplied, assume success
|
|
||||||
r.Message = "success"
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonBytes, err := json.Marshal(r)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("Could not marshal CallResult to JSON: %v\n", err)
|
|
||||||
}
|
|
||||||
return string(jsonBytes), nil
|
|
||||||
}
|
|
@ -54,7 +54,7 @@ func ChatCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
if len(args) != 0 {
|
if len(args) != 0 {
|
||||||
return nil, compMode
|
return nil, compMode
|
||||||
}
|
}
|
||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ func CloneCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
clone, messageCnt, err := ctx.Store.CloneConversation(*toClone)
|
clone, messageCnt, err := ctx.Conversations.CloneConversation(*toClone)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Failed to clone conversation: %v", err)
|
return fmt.Errorf("Failed to clone conversation: %v", err)
|
||||||
}
|
}
|
||||||
@ -40,7 +40,7 @@ func CloneCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
if len(args) != 0 {
|
if len(args) != 0 {
|
||||||
return nil, compMode
|
return nil, compMode
|
||||||
}
|
}
|
||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
return cmd
|
return cmd
|
||||||
|
@ -29,9 +29,9 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
}
|
}
|
||||||
|
|
||||||
shortName := args[0]
|
shortName := args[0]
|
||||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
c := cmdutil.LookupConversation(ctx, shortName)
|
||||||
|
|
||||||
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
|
messages, err := ctx.Conversations.PathToLeaf(c.SelectedRoot)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not retrieve conversation messages: %v", err)
|
return fmt.Errorf("could not retrieve conversation messages: %v", err)
|
||||||
}
|
}
|
||||||
@ -58,7 +58,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
lastMessage.Content += strings.TrimRight(continuedOutput.Content, "\n\t ")
|
lastMessage.Content += strings.TrimRight(continuedOutput.Content, "\n\t ")
|
||||||
|
|
||||||
// Update the original message
|
// Update the original message
|
||||||
err = ctx.Store.UpdateMessage(lastMessage)
|
err = ctx.Conversations.UpdateMessage(lastMessage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not update the last message: %v", err)
|
return fmt.Errorf("could not update the last message: %v", err)
|
||||||
}
|
}
|
||||||
@ -70,7 +70,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
if len(args) != 0 {
|
if len(args) != 0 {
|
||||||
return nil, compMode
|
return nil, compMode
|
||||||
}
|
}
|
||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
applyGenerationFlags(ctx, cmd)
|
applyGenerationFlags(ctx, cmd)
|
||||||
|
@ -22,11 +22,11 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
},
|
},
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
shortName := args[0]
|
shortName := args[0]
|
||||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
c := cmdutil.LookupConversation(ctx, shortName)
|
||||||
|
|
||||||
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
|
messages, err := ctx.Conversations.PathToLeaf(c.SelectedRoot)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
return fmt.Errorf("Could not retrieve messages for conversation: %s", c.Title)
|
||||||
}
|
}
|
||||||
|
|
||||||
offset, _ := cmd.Flags().GetInt("offset")
|
offset, _ := cmd.Flags().GetInt("offset")
|
||||||
@ -62,11 +62,11 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
// Update the message in-place
|
// Update the message in-place
|
||||||
inplace, _ := cmd.Flags().GetBool("in-place")
|
inplace, _ := cmd.Flags().GetBool("in-place")
|
||||||
if inplace {
|
if inplace {
|
||||||
return ctx.Store.UpdateMessage(&toEdit)
|
return ctx.Conversations.UpdateMessage(&toEdit)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise, create a branch for the edited message
|
// Otherwise, create a branch for the edited message
|
||||||
message, _, err := ctx.Store.CloneBranch(toEdit)
|
message, _, err := ctx.Conversations.CloneBranch(toEdit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -74,11 +74,11 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
if desiredIdx > 0 {
|
if desiredIdx > 0 {
|
||||||
// update selected reply
|
// update selected reply
|
||||||
messages[desiredIdx-1].SelectedReply = message
|
messages[desiredIdx-1].SelectedReply = message
|
||||||
err = ctx.Store.UpdateMessage(&messages[desiredIdx-1])
|
err = ctx.Conversations.UpdateMessage(&messages[desiredIdx-1])
|
||||||
} else {
|
} else {
|
||||||
// update selected root
|
// update selected root
|
||||||
conversation.SelectedRoot = message
|
c.SelectedRoot = message
|
||||||
err = ctx.Store.UpdateConversation(conversation)
|
err = ctx.Conversations.UpdateConversation(c)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
},
|
},
|
||||||
@ -87,7 +87,7 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
if len(args) != 0 {
|
if len(args) != 0 {
|
||||||
return nil, compMode
|
return nil, compMode
|
||||||
}
|
}
|
||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
Short: "List conversations",
|
Short: "List conversations",
|
||||||
Long: `List conversations in order of recent activity`,
|
Long: `List conversations in order of recent activity`,
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
messages, err := ctx.Store.LatestConversationMessages()
|
messages, err := ctx.Conversations.LatestConversationMessages()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Could not fetch conversations: %v", err)
|
return fmt.Errorf("Could not fetch conversations: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
@ -25,12 +26,12 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return fmt.Errorf("No message was provided.")
|
return fmt.Errorf("No message was provided.")
|
||||||
}
|
}
|
||||||
|
|
||||||
messages := []api.Message{{
|
messages := []conversation.Message{{
|
||||||
Role: api.MessageRoleUser,
|
Role: api.MessageRoleUser,
|
||||||
Content: input,
|
Content: input,
|
||||||
}}
|
}}
|
||||||
|
|
||||||
conversation, messages, err := ctx.Store.StartConversation(messages...)
|
conversation, messages, err := ctx.Conversations.StartConversation(messages...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Could not start a new conversation: %v", err)
|
return fmt.Errorf("Could not start a new conversation: %v", err)
|
||||||
}
|
}
|
||||||
@ -43,7 +44,7 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
}
|
}
|
||||||
|
|
||||||
conversation.Title = title
|
conversation.Title = title
|
||||||
err = ctx.Store.UpdateConversation(conversation)
|
err = ctx.Conversations.UpdateConversation(conversation)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Warn("Could not save conversation title: %v\n", err)
|
lmcli.Warn("Could not save conversation title: %v\n", err)
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
@ -25,7 +26,7 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return fmt.Errorf("No message was provided.")
|
return fmt.Errorf("No message was provided.")
|
||||||
}
|
}
|
||||||
|
|
||||||
messages := []api.Message{{
|
messages := []conversation.Message{{
|
||||||
Role: api.MessageRoleUser,
|
Role: api.MessageRoleUser,
|
||||||
Content: input,
|
Content: input,
|
||||||
}}
|
}}
|
||||||
|
@ -4,8 +4,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
@ -23,14 +23,14 @@ func RemoveCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
var toRemove []*api.Conversation
|
var toRemove []*conversation.Conversation
|
||||||
for _, shortName := range args {
|
for _, shortName := range args {
|
||||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||||
toRemove = append(toRemove, conversation)
|
toRemove = append(toRemove, conversation)
|
||||||
}
|
}
|
||||||
var errors []error
|
var errors []error
|
||||||
for _, c := range toRemove {
|
for _, c := range toRemove {
|
||||||
err := ctx.Store.DeleteConversation(c)
|
err := ctx.Conversations.DeleteConversation(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errors = append(errors, fmt.Errorf("Could not remove conversation %s: %v", c.ShortName.String, err))
|
errors = append(errors, fmt.Errorf("Could not remove conversation %s: %v", c.ShortName.String, err))
|
||||||
}
|
}
|
||||||
@ -44,7 +44,7 @@ func RemoveCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||||
var completions []string
|
var completions []string
|
||||||
outer:
|
outer:
|
||||||
for _, completion := range ctx.Store.ConversationShortNameCompletions(toComplete) {
|
for _, completion := range ctx.Conversations.ConversationShortNameCompletions(toComplete) {
|
||||||
parts := strings.Split(completion, "\t")
|
parts := strings.Split(completion, "\t")
|
||||||
for _, arg := range args {
|
for _, arg := range args {
|
||||||
if parts[0] == arg {
|
if parts[0] == arg {
|
||||||
|
@ -30,7 +30,7 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
|
|
||||||
generate, _ := cmd.Flags().GetBool("generate")
|
generate, _ := cmd.Flags().GetBool("generate")
|
||||||
if generate {
|
if generate {
|
||||||
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
|
messages, err := ctx.Conversations.PathToLeaf(conversation.SelectedRoot)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Could not retrieve conversation messages: %v", err)
|
return fmt.Errorf("Could not retrieve conversation messages: %v", err)
|
||||||
}
|
}
|
||||||
@ -46,7 +46,7 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
}
|
}
|
||||||
|
|
||||||
conversation.Title = title
|
conversation.Title = title
|
||||||
err = ctx.Store.UpdateConversation(conversation)
|
err = ctx.Conversations.UpdateConversation(conversation)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Warn("Could not update conversation title: %v\n", err)
|
lmcli.Warn("Could not update conversation title: %v\n", err)
|
||||||
}
|
}
|
||||||
@ -57,7 +57,7 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
if len(args) != 0 {
|
if len(args) != 0 {
|
||||||
return nil, compMode
|
return nil, compMode
|
||||||
}
|
}
|
||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
@ -28,14 +29,14 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
}
|
}
|
||||||
|
|
||||||
shortName := args[0]
|
shortName := args[0]
|
||||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
c := cmdutil.LookupConversation(ctx, shortName)
|
||||||
|
|
||||||
reply := inputFromArgsOrEditor(args[1:], "# How would you like to reply?\n", "")
|
reply := inputFromArgsOrEditor(args[1:], "# How would you like to reply?\n", "")
|
||||||
if reply == "" {
|
if reply == "" {
|
||||||
return fmt.Errorf("No reply was provided.")
|
return fmt.Errorf("No reply was provided.")
|
||||||
}
|
}
|
||||||
|
|
||||||
cmdutil.HandleConversationReply(ctx, conversation, true, api.Message{
|
cmdutil.HandleConversationReply(ctx, c, true, conversation.Message{
|
||||||
Role: api.MessageRoleUser,
|
Role: api.MessageRoleUser,
|
||||||
Content: reply,
|
Content: reply,
|
||||||
})
|
})
|
||||||
@ -46,7 +47,7 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
if len(args) != 0 {
|
if len(args) != 0 {
|
||||||
return nil, compMode
|
return nil, compMode
|
||||||
}
|
}
|
||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -28,12 +28,12 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
}
|
}
|
||||||
|
|
||||||
shortName := args[0]
|
shortName := args[0]
|
||||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
c := cmdutil.LookupConversation(ctx, shortName)
|
||||||
|
|
||||||
// Load the complete thread from the root message
|
// Load the complete thread from the root message
|
||||||
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
|
messages, err := ctx.Conversations.PathToLeaf(c.SelectedRoot)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
return fmt.Errorf("Could not retrieve messages for conversation: %s", c.Title)
|
||||||
}
|
}
|
||||||
|
|
||||||
offset, _ := cmd.Flags().GetInt("offset")
|
offset, _ := cmd.Flags().GetInt("offset")
|
||||||
@ -67,7 +67,7 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
if len(args) != 0 {
|
if len(args) != 0 {
|
||||||
return nil, compMode
|
return nil, compMode
|
||||||
}
|
}
|
||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -9,7 +9,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||||
"github.com/charmbracelet/lipgloss"
|
"github.com/charmbracelet/lipgloss"
|
||||||
@ -17,7 +18,7 @@ import (
|
|||||||
|
|
||||||
// Prompt prompts the configured the configured model and streams the response
|
// Prompt prompts the configured the configured model and streams the response
|
||||||
// to stdout. Returns all model reply messages.
|
// to stdout. Returns all model reply messages.
|
||||||
func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Message)) (*api.Message, error) {
|
func Prompt(ctx *lmcli.Context, messages []conversation.Message, callback func(conversation.Message)) (*api.Message, error) {
|
||||||
m, _, p, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
|
m, _, p, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -40,7 +41,7 @@ func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Messag
|
|||||||
}
|
}
|
||||||
|
|
||||||
if system != "" {
|
if system != "" {
|
||||||
messages = api.ApplySystemPrompt(messages, system, false)
|
messages = conversation.ApplySystemPrompt(messages, system, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
content := make(chan provider.Chunk)
|
content := make(chan provider.Chunk)
|
||||||
@ -50,7 +51,7 @@ func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Messag
|
|||||||
go ShowDelayedContent(content)
|
go ShowDelayedContent(content)
|
||||||
|
|
||||||
reply, err := p.CreateChatCompletionStream(
|
reply, err := p.CreateChatCompletionStream(
|
||||||
context.Background(), params, messages, content,
|
context.Background(), params, conversation.MessagesToAPI(messages), content,
|
||||||
)
|
)
|
||||||
|
|
||||||
if reply.Content != "" {
|
if reply.Content != "" {
|
||||||
@ -67,8 +68,8 @@ func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Messag
|
|||||||
|
|
||||||
// lookupConversation either returns the conversation found by the
|
// lookupConversation either returns the conversation found by the
|
||||||
// short name or exits the program
|
// short name or exits the program
|
||||||
func LookupConversation(ctx *lmcli.Context, shortName string) *api.Conversation {
|
func LookupConversation(ctx *lmcli.Context, shortName string) *conversation.Conversation {
|
||||||
c, err := ctx.Store.ConversationByShortName(shortName)
|
c, err := ctx.Conversations.FindConversationByShortName(shortName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Fatal("Could not lookup conversation: %v\n", err)
|
lmcli.Fatal("Could not lookup conversation: %v\n", err)
|
||||||
}
|
}
|
||||||
@ -78,8 +79,8 @@ func LookupConversation(ctx *lmcli.Context, shortName string) *api.Conversation
|
|||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func LookupConversationE(ctx *lmcli.Context, shortName string) (*api.Conversation, error) {
|
func LookupConversationE(ctx *lmcli.Context, shortName string) (*conversation.Conversation, error) {
|
||||||
c, err := ctx.Store.ConversationByShortName(shortName)
|
c, err := ctx.Conversations.FindConversationByShortName(shortName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Could not lookup conversation: %v", err)
|
return nil, fmt.Errorf("Could not lookup conversation: %v", err)
|
||||||
}
|
}
|
||||||
@ -89,8 +90,8 @@ func LookupConversationE(ctx *lmcli.Context, shortName string) (*api.Conversatio
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandleConversationReply(ctx *lmcli.Context, c *api.Conversation, persist bool, toSend ...api.Message) {
|
func HandleConversationReply(ctx *lmcli.Context, c *conversation.Conversation, persist bool, toSend ...conversation.Message) {
|
||||||
messages, err := ctx.Store.PathToLeaf(c.SelectedRoot)
|
messages, err := ctx.Conversations.PathToLeaf(c.SelectedRoot)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Fatal("Could not load messages: %v\n", err)
|
lmcli.Fatal("Could not load messages: %v\n", err)
|
||||||
}
|
}
|
||||||
@ -99,40 +100,40 @@ func HandleConversationReply(ctx *lmcli.Context, c *api.Conversation, persist bo
|
|||||||
|
|
||||||
// handleConversationReply handles sending messages to an existing
|
// handleConversationReply handles sending messages to an existing
|
||||||
// conversation, optionally persisting both the sent replies and responses.
|
// conversation, optionally persisting both the sent replies and responses.
|
||||||
func HandleReply(ctx *lmcli.Context, to *api.Message, persist bool, messages ...api.Message) {
|
func HandleReply(ctx *lmcli.Context, to *conversation.Message, persist bool, messages ...conversation.Message) {
|
||||||
if to == nil {
|
if to == nil {
|
||||||
lmcli.Fatal("Can't prompt from an empty message.")
|
lmcli.Fatal("Can't prompt from an empty message.")
|
||||||
}
|
}
|
||||||
|
|
||||||
existing, err := ctx.Store.PathToRoot(to)
|
existing, err := ctx.Conversations.PathToRoot(to)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Fatal("Could not load messages: %v\n", err)
|
lmcli.Fatal("Could not load messages: %v\n", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
RenderConversation(ctx, append(existing, messages...), true)
|
RenderConversation(ctx, append(existing, messages...), true)
|
||||||
|
|
||||||
var savedReplies []api.Message
|
var savedReplies []conversation.Message
|
||||||
if persist && len(messages) > 0 {
|
if persist && len(messages) > 0 {
|
||||||
savedReplies, err = ctx.Store.Reply(to, messages...)
|
savedReplies, err = ctx.Conversations.Reply(to, messages...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Warn("Could not save messages: %v\n", err)
|
lmcli.Warn("Could not save messages: %v\n", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// render a message header with no contents
|
// render a message header with no contents
|
||||||
RenderMessage(ctx, (&api.Message{Role: api.MessageRoleAssistant}))
|
RenderMessage(ctx, (&conversation.Message{Role: api.MessageRoleAssistant}))
|
||||||
|
|
||||||
var lastSavedMessage *api.Message
|
var lastSavedMessage *conversation.Message
|
||||||
lastSavedMessage = to
|
lastSavedMessage = to
|
||||||
if len(savedReplies) > 0 {
|
if len(savedReplies) > 0 {
|
||||||
lastSavedMessage = &savedReplies[len(savedReplies)-1]
|
lastSavedMessage = &savedReplies[len(savedReplies)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
replyCallback := func(reply api.Message) {
|
replyCallback := func(reply conversation.Message) {
|
||||||
if !persist {
|
if !persist {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
savedReplies, err = ctx.Store.Reply(lastSavedMessage, reply)
|
savedReplies, err = ctx.Conversations.Reply(lastSavedMessage, reply)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Warn("Could not save reply: %v\n", err)
|
lmcli.Warn("Could not save reply: %v\n", err)
|
||||||
}
|
}
|
||||||
@ -145,7 +146,7 @@ func HandleReply(ctx *lmcli.Context, to *api.Message, persist bool, messages ...
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func FormatForExternalPrompt(messages []api.Message, system bool) string {
|
func FormatForExternalPrompt(messages []conversation.Message, system bool) string {
|
||||||
sb := strings.Builder{}
|
sb := strings.Builder{}
|
||||||
for _, message := range messages {
|
for _, message := range messages {
|
||||||
if message.Content == "" {
|
if message.Content == "" {
|
||||||
@ -164,7 +165,7 @@ func FormatForExternalPrompt(messages []api.Message, system bool) string {
|
|||||||
return sb.String()
|
return sb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenerateTitle(ctx *lmcli.Context, messages []api.Message) (string, error) {
|
func GenerateTitle(ctx *lmcli.Context, messages []conversation.Message) (string, error) {
|
||||||
const systemPrompt = `You will be shown a conversation between a user and an AI assistant. Your task is to generate a short title (8 words or less) for the provided conversation that reflects the conversation's topic. Your response is expected to be in JSON in the format shown below.
|
const systemPrompt = `You will be shown a conversation between a user and an AI assistant. Your task is to generate a short title (8 words or less) for the provided conversation that reflects the conversation's topic. Your response is expected to be in JSON in the format shown below.
|
||||||
|
|
||||||
Example conversation:
|
Example conversation:
|
||||||
@ -189,19 +190,19 @@ Example response:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Serialize the conversation to JSON
|
// Serialize the conversation to JSON
|
||||||
conversation, err := json.Marshal(msgs)
|
jsonBytes, err := json.Marshal(msgs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
generateRequest := []api.Message{
|
generateRequest := []conversation.Message{
|
||||||
{
|
{
|
||||||
Role: api.MessageRoleSystem,
|
Role: api.MessageRoleSystem,
|
||||||
Content: systemPrompt,
|
Content: systemPrompt,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Role: api.MessageRoleUser,
|
Role: api.MessageRoleUser,
|
||||||
Content: string(conversation),
|
Content: string(jsonBytes),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -218,7 +219,7 @@ Example response:
|
|||||||
}
|
}
|
||||||
|
|
||||||
response, err := p.CreateChatCompletion(
|
response, err := p.CreateChatCompletion(
|
||||||
context.Background(), requestParams, generateRequest,
|
context.Background(), requestParams, conversation.MessagesToAPI(generateRequest),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@ -293,7 +294,7 @@ func ShowDelayedContent(content <-chan provider.Chunk) {
|
|||||||
// RenderConversation renders the given messages to TTY, with optional space
|
// RenderConversation renders the given messages to TTY, with optional space
|
||||||
// for a subsequent message. spaceForResponse controls how many '\n' characters
|
// for a subsequent message. spaceForResponse controls how many '\n' characters
|
||||||
// are printed immediately after the final message (1 if false, 2 if true)
|
// are printed immediately after the final message (1 if false, 2 if true)
|
||||||
func RenderConversation(ctx *lmcli.Context, messages []api.Message, spaceForResponse bool) {
|
func RenderConversation(ctx *lmcli.Context, messages []conversation.Message, spaceForResponse bool) {
|
||||||
l := len(messages)
|
l := len(messages)
|
||||||
for i, message := range messages {
|
for i, message := range messages {
|
||||||
RenderMessage(ctx, &message)
|
RenderMessage(ctx, &message)
|
||||||
@ -304,7 +305,7 @@ func RenderConversation(ctx *lmcli.Context, messages []api.Message, spaceForResp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func RenderMessage(ctx *lmcli.Context, m *api.Message) {
|
func RenderMessage(ctx *lmcli.Context, m *conversation.Message) {
|
||||||
var messageAge string
|
var messageAge string
|
||||||
if m.CreatedAt.IsZero() {
|
if m.CreatedAt.IsZero() {
|
||||||
messageAge = "now"
|
messageAge = "now"
|
||||||
|
@ -24,7 +24,7 @@ func ViewCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
shortName := args[0]
|
shortName := args[0]
|
||||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||||
|
|
||||||
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
|
messages, err := ctx.Conversations.PathToLeaf(conversation.SelectedRoot)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Could not retrieve messages for conversation %s: %v", conversation.ShortName.String, err)
|
return fmt.Errorf("Could not retrieve messages for conversation %s: %v", conversation.ShortName.String, err)
|
||||||
}
|
}
|
||||||
@ -37,7 +37,7 @@ func ViewCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
if len(args) != 0 {
|
if len(args) != 0 {
|
||||||
return nil, compMode
|
return nil, compMode
|
||||||
}
|
}
|
||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
98
pkg/conversation/conversation.go
Normal file
98
pkg/conversation/conversation.go
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
package conversation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Conversation struct {
|
||||||
|
ID uint `gorm:"primaryKey"`
|
||||||
|
ShortName sql.NullString
|
||||||
|
Title string
|
||||||
|
SelectedRootID *uint
|
||||||
|
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
|
||||||
|
RootMessages []Message `gorm:"-:all"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MessageMeta struct {
|
||||||
|
GenerationProvider *string `json:"generation_provider,omitempty"`
|
||||||
|
GenerationModel *string `json:"generation_model,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
ID uint `gorm:"primaryKey"`
|
||||||
|
CreatedAt time.Time
|
||||||
|
Metadata MessageMeta
|
||||||
|
|
||||||
|
ConversationID *uint `gorm:"index"`
|
||||||
|
Conversation *Conversation `gorm:"foreignKey:ConversationID"`
|
||||||
|
ParentID *uint
|
||||||
|
Parent *Message `gorm:"foreignKey:ParentID"`
|
||||||
|
Replies []Message `gorm:"foreignKey:ParentID"`
|
||||||
|
SelectedReplyID *uint
|
||||||
|
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
|
||||||
|
|
||||||
|
Role api.MessageRole
|
||||||
|
Content string
|
||||||
|
ToolCalls ToolCalls // a json array of tool calls (from the model)
|
||||||
|
ToolResults ToolResults // a json array of tool results
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MessageMeta) Scan(value interface{}) error {
|
||||||
|
return json.Unmarshal(value.([]byte), m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MessageMeta) Value() (driver.Value, error) {
|
||||||
|
return json.Marshal(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolCalls []api.ToolCall
|
||||||
|
|
||||||
|
func (tc *ToolCalls) Scan(value any) (err error) {
|
||||||
|
s := value.(string)
|
||||||
|
if value == nil || s == "" {
|
||||||
|
*tc = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = json.Unmarshal([]byte(s), tc)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc ToolCalls) Value() (driver.Value, error) {
|
||||||
|
if len(tc) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
jsonBytes, err := json.Marshal(tc)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("Could not marshal ToolCalls to JSON: %v\n", err)
|
||||||
|
}
|
||||||
|
return string(jsonBytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolResults []api.ToolResult
|
||||||
|
|
||||||
|
func (tr *ToolResults) Scan(value any) (err error) {
|
||||||
|
s := value.(string)
|
||||||
|
if value == nil || s == "" {
|
||||||
|
*tr = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = json.Unmarshal([]byte(s), tr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tr ToolResults) Value() (driver.Value, error) {
|
||||||
|
if len(tr) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
jsonBytes, err := json.Marshal([]api.ToolResult(tr))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("Could not marshal ToolResults to JSON: %v\n", err)
|
||||||
|
}
|
||||||
|
return string(jsonBytes), nil
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package lmcli
|
package conversation
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
@ -8,43 +8,57 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
|
||||||
sqids "github.com/sqids/sqids-go"
|
sqids "github.com/sqids/sqids-go"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ConversationStore interface {
|
// Repo exposes low-level message and conversation management. See
|
||||||
ConversationByShortName(shortName string) (*api.Conversation, error)
|
// Service for high-level helpers
|
||||||
|
type Repo interface {
|
||||||
|
// LatestConversationMessages returns a slice of all conversations ordered by when they were last updated (newest to oldest)
|
||||||
|
LatestConversationMessages() ([]Message, error)
|
||||||
|
|
||||||
|
FindConversationByShortName(shortName string) (*Conversation, error)
|
||||||
ConversationShortNameCompletions(search string) []string
|
ConversationShortNameCompletions(search string) []string
|
||||||
RootMessages(conversationID uint) ([]api.Message, error)
|
GetConversationByID(int uint) (*Conversation, error)
|
||||||
LatestConversationMessages() ([]api.Message, error)
|
GetRootMessages(conversationID uint) ([]Message, error)
|
||||||
|
|
||||||
StartConversation(messages ...api.Message) (*api.Conversation, []api.Message, error)
|
CreateConversation(title string) (*Conversation, error)
|
||||||
UpdateConversation(conversation *api.Conversation) error
|
UpdateConversation(*Conversation) error
|
||||||
DeleteConversation(conversation *api.Conversation) error
|
DeleteConversation(*Conversation) error
|
||||||
CloneConversation(toClone api.Conversation) (*api.Conversation, uint, error)
|
|
||||||
|
|
||||||
MessageByID(messageID uint) (*api.Message, error)
|
GetMessageByID(messageID uint) (*Message, error)
|
||||||
MessageReplies(messageID uint) ([]api.Message, error)
|
|
||||||
|
|
||||||
UpdateMessage(message *api.Message) error
|
SaveMessage(message Message) (*Message, error)
|
||||||
DeleteMessage(message *api.Message, prune bool) error
|
UpdateMessage(message *Message) error
|
||||||
CloneBranch(toClone api.Message) (*api.Message, uint, error)
|
DeleteMessage(message *Message, prune bool) error
|
||||||
Reply(to *api.Message, messages ...api.Message) ([]api.Message, error)
|
CloneBranch(toClone Message) (*Message, uint, error)
|
||||||
|
Reply(to *Message, messages ...Message) ([]Message, error)
|
||||||
|
|
||||||
PathToRoot(message *api.Message) ([]api.Message, error)
|
PathToRoot(message *Message) ([]Message, error)
|
||||||
PathToLeaf(message *api.Message) ([]api.Message, error)
|
PathToLeaf(message *Message) ([]Message, error)
|
||||||
|
|
||||||
|
// Retrieves and return the "selected thread" of the conversation.
|
||||||
|
// The "selected thread" of the conversation is a chain of messages
|
||||||
|
// starting from the Conversation's SelectedRoot Message, following each
|
||||||
|
// Message's SelectedReply until the tail Message is reached.
|
||||||
|
GetSelectedThread(*Conversation) ([]Message, error)
|
||||||
|
|
||||||
|
// Start a new conversation with the given messages
|
||||||
|
StartConversation(messages ...Message) (*Conversation, []Message, error)
|
||||||
|
|
||||||
|
CloneConversation(toClone Conversation) (*Conversation, uint, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type SQLStore struct {
|
type repo struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
sqids *sqids.Sqids
|
sqids *sqids.Sqids
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
|
func NewRepo(db *gorm.DB) (Repo, error) {
|
||||||
models := []any{
|
models := []any{
|
||||||
&api.Conversation{},
|
&Conversation{},
|
||||||
&api.Message{},
|
&Message{},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, x := range models {
|
for _, x := range models {
|
||||||
@ -55,12 +69,70 @@ func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_sqids, _ := sqids.New(sqids.Options{MinLength: 4})
|
_sqids, _ := sqids.New(sqids.Options{MinLength: 4})
|
||||||
return &SQLStore{db, _sqids}, nil
|
return &repo{db, _sqids}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) createConversation() (*api.Conversation, error) {
|
func (s *repo) LatestConversationMessages() ([]Message, error) {
|
||||||
|
var latestMessages []Message
|
||||||
|
|
||||||
|
subQuery := s.db.Model(&Message{}).
|
||||||
|
Select("MAX(created_at) as max_created_at, conversation_id").
|
||||||
|
Group("conversation_id")
|
||||||
|
|
||||||
|
err := s.db.Model(&Message{}).
|
||||||
|
Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery).
|
||||||
|
Group("messages.conversation_id").
|
||||||
|
Order("created_at DESC").
|
||||||
|
Preload("Conversation.SelectedRoot").
|
||||||
|
Find(&latestMessages).Error
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return latestMessages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *repo) FindConversationByShortName(shortName string) (*Conversation, error) {
|
||||||
|
if shortName == "" {
|
||||||
|
return nil, errors.New("shortName is empty")
|
||||||
|
}
|
||||||
|
var conversation Conversation
|
||||||
|
err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error
|
||||||
|
return &conversation, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *repo) ConversationShortNameCompletions(shortName string) []string {
|
||||||
|
var conversations []Conversation
|
||||||
|
// ignore error for completions
|
||||||
|
s.db.Find(&conversations)
|
||||||
|
completions := make([]string, 0, len(conversations))
|
||||||
|
for _, conversation := range conversations {
|
||||||
|
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
|
||||||
|
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return completions
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *repo) GetConversationByID(id uint) (*Conversation, error) {
|
||||||
|
var conversation Conversation
|
||||||
|
err := s.db.Preload("SelectedRoot").Where("id = ?", id).Find(&conversation).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Cannot get conversation %d: %v", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rootMessages, err := s.GetRootMessages(id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not load conversation's root messages %d: %v", id, err)
|
||||||
|
}
|
||||||
|
conversation.RootMessages = rootMessages
|
||||||
|
return &conversation, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *repo) CreateConversation(title string) (*Conversation, error) {
|
||||||
// Create the new conversation
|
// Create the new conversation
|
||||||
c := &api.Conversation{}
|
c := &Conversation{Title: title}
|
||||||
err := s.db.Save(c).Error
|
err := s.db.Save(c).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -75,159 +147,54 @@ func (s *SQLStore) createConversation() (*api.Conversation, error) {
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) UpdateConversation(c *api.Conversation) error {
|
func (s *repo) UpdateConversation(c *Conversation) error {
|
||||||
if c == nil || c.ID == 0 {
|
if c == nil || c.ID == 0 {
|
||||||
return fmt.Errorf("Conversation is nil or invalid (missing ID)")
|
return fmt.Errorf("Conversation is nil or invalid (missing ID)")
|
||||||
}
|
}
|
||||||
return s.db.Updates(c).Error
|
return s.db.Updates(c).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) DeleteConversation(c *api.Conversation) error {
|
func (s *repo) DeleteConversation(c *Conversation) error {
|
||||||
|
if c == nil || c.ID == 0 {
|
||||||
|
return fmt.Errorf("Conversation is nil or invalid (missing ID)")
|
||||||
|
}
|
||||||
// Delete messages first
|
// Delete messages first
|
||||||
err := s.db.Where("conversation_id = ?", c.ID).Delete(&api.Message{}).Error
|
err := s.db.Where("conversation_id = ?", c.ID).Delete(&Message{}).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return s.db.Delete(c).Error
|
return s.db.Delete(c).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) DeleteMessage(message *api.Message, prune bool) error {
|
func (s *repo) SaveMessage(m Message) (*Message, error) {
|
||||||
panic("Not yet implemented")
|
if m.Conversation == nil {
|
||||||
//return s.db.Delete(&message).Error
|
return nil, fmt.Errorf("Can't save a message without a conversation (this is a bug)")
|
||||||
|
}
|
||||||
|
newMessage := m
|
||||||
|
newMessage.ID = 0
|
||||||
|
return &newMessage, s.db.Create(&newMessage).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) UpdateMessage(m *api.Message) error {
|
func (s *repo) UpdateMessage(m *Message) error {
|
||||||
if m == nil || m.ID == 0 {
|
if m == nil || m.ID == 0 {
|
||||||
return fmt.Errorf("Message is nil or invalid (missing ID)")
|
return fmt.Errorf("Message is nil or invalid (missing ID)")
|
||||||
}
|
}
|
||||||
return s.db.Updates(m).Error
|
return s.db.Updates(m).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
|
func (s *repo) DeleteMessage(message *Message, prune bool) error {
|
||||||
var conversations []api.Conversation
|
return s.db.Delete(&message).Error
|
||||||
// ignore error for completions
|
|
||||||
s.db.Find(&conversations)
|
|
||||||
completions := make([]string, 0, len(conversations))
|
|
||||||
for _, conversation := range conversations {
|
|
||||||
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
|
|
||||||
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return completions
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) ConversationByShortName(shortName string) (*api.Conversation, error) {
|
func (s *repo) GetMessageByID(messageID uint) (*Message, error) {
|
||||||
if shortName == "" {
|
var message Message
|
||||||
return nil, errors.New("shortName is empty")
|
|
||||||
}
|
|
||||||
var conversation api.Conversation
|
|
||||||
err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error
|
|
||||||
return &conversation, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SQLStore) RootMessages(conversationID uint) ([]api.Message, error) {
|
|
||||||
var rootMessages []api.Message
|
|
||||||
err := s.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return rootMessages, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SQLStore) MessageByID(messageID uint) (*api.Message, error) {
|
|
||||||
var message api.Message
|
|
||||||
err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error
|
err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error
|
||||||
return &message, err
|
return &message, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) MessageReplies(messageID uint) ([]api.Message, error) {
|
// Reply to a message with a series of messages (each followed by the next)
|
||||||
var replies []api.Message
|
func (s *repo) Reply(to *Message, messages ...Message) ([]Message, error) {
|
||||||
err := s.db.Where("parent_id = ?", messageID).Find(&replies).Error
|
var savedMessages []Message
|
||||||
return replies, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// StartConversation starts a new conversation with the provided messages
|
|
||||||
func (s *SQLStore) StartConversation(messages ...api.Message) (*api.Conversation, []api.Message, error) {
|
|
||||||
if len(messages) == 0 {
|
|
||||||
return nil, nil, fmt.Errorf("Must provide at least 1 message")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create new conversation
|
|
||||||
conversation, err := s.createConversation()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create first message
|
|
||||||
messages[0].Conversation = conversation
|
|
||||||
err = s.db.Create(&messages[0]).Error
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update conversation's selected root message
|
|
||||||
conversation.SelectedRoot = &messages[0]
|
|
||||||
err = s.UpdateConversation(conversation)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add additional replies to conversation
|
|
||||||
if len(messages) > 1 {
|
|
||||||
newMessages, err := s.Reply(&messages[0], messages[1:]...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
messages = append([]api.Message{messages[0]}, newMessages...)
|
|
||||||
}
|
|
||||||
return conversation, messages, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CloneConversation clones the given conversation and all of its root meesages
|
|
||||||
func (s *SQLStore) CloneConversation(toClone api.Conversation) (*api.Conversation, uint, error) {
|
|
||||||
rootMessages, err := s.RootMessages(toClone.ID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
clone, err := s.createConversation()
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, fmt.Errorf("Could not create clone: %s", err)
|
|
||||||
}
|
|
||||||
clone.Title = toClone.Title + " - Clone"
|
|
||||||
|
|
||||||
var errors []error
|
|
||||||
var messageCnt uint = 0
|
|
||||||
for _, root := range rootMessages {
|
|
||||||
messageCnt++
|
|
||||||
newRoot := root
|
|
||||||
newRoot.ConversationID = &clone.ID
|
|
||||||
|
|
||||||
cloned, count, err := s.CloneBranch(newRoot)
|
|
||||||
if err != nil {
|
|
||||||
errors = append(errors, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
messageCnt += count
|
|
||||||
|
|
||||||
if root.ID == *toClone.SelectedRootID {
|
|
||||||
clone.SelectedRootID = &cloned.ID
|
|
||||||
if err := s.UpdateConversation(clone); err != nil {
|
|
||||||
errors = append(errors, fmt.Errorf("Could not set selected root on clone: %v", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(errors) > 0 {
|
|
||||||
return nil, 0, fmt.Errorf("Messages failed to be cloned: %v", errors)
|
|
||||||
}
|
|
||||||
|
|
||||||
return clone, messageCnt, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reply to a message with a series of messages (each following the next)
|
|
||||||
func (s *SQLStore) Reply(to *api.Message, messages ...api.Message) ([]api.Message, error) {
|
|
||||||
var savedMessages []api.Message
|
|
||||||
|
|
||||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||||
currentParent := to
|
currentParent := to
|
||||||
@ -262,17 +229,14 @@ func (s *SQLStore) Reply(to *api.Message, messages ...api.Message) ([]api.Messag
|
|||||||
// CloneBranch returns a deep clone of the given message and its replies, returning
|
// CloneBranch returns a deep clone of the given message and its replies, returning
|
||||||
// a new message object. The new message will be attached to the same parent as
|
// a new message object. The new message will be attached to the same parent as
|
||||||
// the messageToClone
|
// the messageToClone
|
||||||
func (s *SQLStore) CloneBranch(messageToClone api.Message) (*api.Message, uint, error) {
|
func (s *repo) CloneBranch(messageToClone Message) (*Message, uint, error) {
|
||||||
newMessage := messageToClone
|
newMessage := messageToClone
|
||||||
newMessage.ID = 0
|
newMessage.ID = 0
|
||||||
newMessage.Replies = nil
|
newMessage.Replies = nil
|
||||||
newMessage.SelectedReplyID = nil
|
newMessage.SelectedReplyID = nil
|
||||||
newMessage.SelectedReply = nil
|
newMessage.SelectedReply = nil
|
||||||
|
|
||||||
originalReplies, err := s.MessageReplies(messageToClone.ID)
|
originalReplies := messageToClone.Replies
|
||||||
if err != nil {
|
|
||||||
return nil, 0, fmt.Errorf("Could not fetch message %d replies: %v", messageToClone.ID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.db.Create(&newMessage).Error; err != nil {
|
if err := s.db.Create(&newMessage).Error; err != nil {
|
||||||
return nil, 0, fmt.Errorf("Could not clone message: %s", err)
|
return nil, 0, fmt.Errorf("Could not clone message: %s", err)
|
||||||
@ -304,19 +268,19 @@ func (s *SQLStore) CloneBranch(messageToClone api.Message) (*api.Message, uint,
|
|||||||
return &newMessage, replyCount, nil
|
return &newMessage, replyCount, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchMessages(db *gorm.DB) ([]api.Message, error) {
|
func fetchMessages(db *gorm.DB) ([]Message, error) {
|
||||||
var messages []api.Message
|
var messages []Message
|
||||||
if err := db.Preload("Conversation").Find(&messages).Error; err != nil {
|
if err := db.Preload("Conversation").Find(&messages).Error; err != nil {
|
||||||
return nil, fmt.Errorf("Could not fetch messages: %v", err)
|
return nil, fmt.Errorf("Could not fetch messages: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
messageMap := make(map[uint]api.Message)
|
messageMap := make(map[uint]Message)
|
||||||
for i, message := range messages {
|
for i, message := range messages {
|
||||||
messageMap[messages[i].ID] = message
|
messageMap[messages[i].ID] = message
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a map to store replies by their parent ID
|
// Create a map to store replies by their parent ID
|
||||||
repliesMap := make(map[uint][]api.Message)
|
repliesMap := make(map[uint][]Message)
|
||||||
for i, message := range messages {
|
for i, message := range messages {
|
||||||
if messages[i].ParentID != nil {
|
if messages[i].ParentID != nil {
|
||||||
repliesMap[*messages[i].ParentID] = append(repliesMap[*messages[i].ParentID], message)
|
repliesMap[*messages[i].ParentID] = append(repliesMap[*messages[i].ParentID], message)
|
||||||
@ -326,7 +290,7 @@ func fetchMessages(db *gorm.DB) ([]api.Message, error) {
|
|||||||
// Assign replies, parent, and selected reply to each message
|
// Assign replies, parent, and selected reply to each message
|
||||||
for i := range messages {
|
for i := range messages {
|
||||||
if replies, exists := repliesMap[messages[i].ID]; exists {
|
if replies, exists := repliesMap[messages[i].ID]; exists {
|
||||||
messages[i].Replies = make([]api.Message, len(replies))
|
messages[i].Replies = make([]Message, len(replies))
|
||||||
for j, m := range replies {
|
for j, m := range replies {
|
||||||
messages[i].Replies[j] = m
|
messages[i].Replies[j] = m
|
||||||
}
|
}
|
||||||
@ -345,21 +309,51 @@ func fetchMessages(db *gorm.DB) ([]api.Message, error) {
|
|||||||
return messages, nil
|
return messages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) buildPath(message *api.Message, getNext func(*api.Message) *uint) ([]api.Message, error) {
|
func (r repo) GetRootMessages(conversationID uint) ([]Message, error) {
|
||||||
var messages []api.Message
|
var rootMessages []Message
|
||||||
|
err := r.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not retrieve root messages for conversation %d: %v", conversationID, err)
|
||||||
|
}
|
||||||
|
return rootMessages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *repo) buildPath(message *Message, getNext func(*Message) *uint) ([]Message, error) {
|
||||||
|
var messages []Message
|
||||||
messages, err := fetchMessages(s.db.Where("conversation_id = ?", message.ConversationID))
|
messages, err := fetchMessages(s.db.Where("conversation_id = ?", message.ConversationID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a map to store messages by their ID
|
// Create a map to store messages by their ID
|
||||||
messageMap := make(map[uint]*api.Message)
|
messageMap := make(map[uint]*Message, len(messages))
|
||||||
for i := range messages {
|
for i := range messages {
|
||||||
messageMap[messages[i].ID] = &messages[i]
|
messageMap[messages[i].ID] = &messages[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Construct Replies
|
||||||
|
repliesMap := make(map[uint][]*Message, len(messages))
|
||||||
|
for _, m := range messageMap {
|
||||||
|
if m.ParentID == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if p, ok := messageMap[*m.ParentID]; ok {
|
||||||
|
repliesMap[p.ID] = append(repliesMap[p.ID], m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add replies to messages
|
||||||
|
for _, m := range messageMap {
|
||||||
|
if replies, ok := repliesMap[m.ID]; ok {
|
||||||
|
m.Replies = make([]Message, len(replies))
|
||||||
|
for idx, reply := range replies {
|
||||||
|
m.Replies[idx] = *reply
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Build the path
|
// Build the path
|
||||||
var path []api.Message
|
var path []Message
|
||||||
nextID := &message.ID
|
nextID := &message.ID
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@ -382,12 +376,12 @@ func (s *SQLStore) buildPath(message *api.Message, getNext func(*api.Message) *u
|
|||||||
// PathToRoot traverses the provided message's Parent until reaching the tree
|
// PathToRoot traverses the provided message's Parent until reaching the tree
|
||||||
// root and returns a slice of all messages traversed in chronological order
|
// root and returns a slice of all messages traversed in chronological order
|
||||||
// (starting with the root and ending with the message provided)
|
// (starting with the root and ending with the message provided)
|
||||||
func (s *SQLStore) PathToRoot(message *api.Message) ([]api.Message, error) {
|
func (s *repo) PathToRoot(message *Message) ([]Message, error) {
|
||||||
if message == nil || message.ID <= 0 {
|
if message == nil || message.ID <= 0 {
|
||||||
return nil, fmt.Errorf("Message is nil or has invalid ID")
|
return nil, fmt.Errorf("Message is nil or has invalid ID")
|
||||||
}
|
}
|
||||||
|
|
||||||
path, err := s.buildPath(message, func(m *api.Message) *uint {
|
path, err := s.buildPath(message, func(m *Message) *uint {
|
||||||
return m.ParentID
|
return m.ParentID
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -401,33 +395,99 @@ func (s *SQLStore) PathToRoot(message *api.Message) ([]api.Message, error) {
|
|||||||
// PathToLeaf traverses the provided message's SelectedReply until reaching a
|
// PathToLeaf traverses the provided message's SelectedReply until reaching a
|
||||||
// tree leaf and returns a slice of all messages traversed in chronological
|
// tree leaf and returns a slice of all messages traversed in chronological
|
||||||
// order (starting with the message provided and ending with the leaf)
|
// order (starting with the message provided and ending with the leaf)
|
||||||
func (s *SQLStore) PathToLeaf(message *api.Message) ([]api.Message, error) {
|
func (s *repo) PathToLeaf(message *Message) ([]Message, error) {
|
||||||
if message == nil || message.ID <= 0 {
|
if message == nil || message.ID <= 0 {
|
||||||
return nil, fmt.Errorf("Message is nil or has invalid ID")
|
return nil, fmt.Errorf("Message is nil or has invalid ID")
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.buildPath(message, func(m *api.Message) *uint {
|
return s.buildPath(message, func(m *Message) *uint {
|
||||||
return m.SelectedReplyID
|
return m.SelectedReplyID
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) LatestConversationMessages() ([]api.Message, error) {
|
func (s *repo) StartConversation(messages ...Message) (*Conversation, []Message, error) {
|
||||||
var latestMessages []api.Message
|
if len(messages) == 0 {
|
||||||
|
return nil, nil, fmt.Errorf("Must provide at least 1 message")
|
||||||
subQuery := s.db.Model(&api.Message{}).
|
|
||||||
Select("MAX(created_at) as max_created_at, conversation_id").
|
|
||||||
Group("conversation_id")
|
|
||||||
|
|
||||||
err := s.db.Model(&api.Message{}).
|
|
||||||
Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery).
|
|
||||||
Group("messages.conversation_id").
|
|
||||||
Order("created_at DESC").
|
|
||||||
Preload("Conversation.SelectedRoot").
|
|
||||||
Find(&latestMessages).Error
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return latestMessages, nil
|
// Create new conversation
|
||||||
|
conversation, err := s.CreateConversation("")
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
messages[0].Conversation = conversation
|
||||||
|
|
||||||
|
// Create first message
|
||||||
|
firstMessage, err := s.SaveMessage(messages[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
messages[0] = *firstMessage
|
||||||
|
|
||||||
|
// Update conversation's selected root message
|
||||||
|
conversation.RootMessages = []Message{messages[0]}
|
||||||
|
conversation.SelectedRoot = &messages[0]
|
||||||
|
err = s.UpdateConversation(conversation)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add additional replies to conversation
|
||||||
|
if len(messages) > 1 {
|
||||||
|
newMessages, err := s.Reply(&messages[0], messages[1:]...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
messages = append([]Message{messages[0]}, newMessages...)
|
||||||
|
}
|
||||||
|
return conversation, messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// CloneConversation clones the given conversation and all of its meesages
|
||||||
|
func (s *repo) CloneConversation(toClone Conversation) (*Conversation, uint, error) {
|
||||||
|
rootMessages, err := s.GetRootMessages(toClone.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("Could not create clone: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clone, err := s.CreateConversation(toClone.Title + " - Clone")
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("Could not create clone: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var errors []error
|
||||||
|
var messageCnt uint = 0
|
||||||
|
for _, root := range rootMessages {
|
||||||
|
messageCnt++
|
||||||
|
newRoot := root
|
||||||
|
newRoot.ConversationID = &clone.ID
|
||||||
|
|
||||||
|
cloned, count, err := s.CloneBranch(newRoot)
|
||||||
|
if err != nil {
|
||||||
|
errors = append(errors, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
messageCnt += count
|
||||||
|
|
||||||
|
if root.ID == *toClone.SelectedRootID {
|
||||||
|
clone.SelectedRootID = &cloned.ID
|
||||||
|
if err := s.UpdateConversation(clone); err != nil {
|
||||||
|
errors = append(errors, fmt.Errorf("Could not set selected root on clone: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errors) > 0 {
|
||||||
|
return nil, 0, fmt.Errorf("Messages failed to be cloned: %v", errors)
|
||||||
|
}
|
||||||
|
|
||||||
|
return clone, messageCnt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *repo) GetSelectedThread(c *Conversation) ([]Message, error) {
|
||||||
|
if c.SelectedRoot == nil {
|
||||||
|
return nil, fmt.Errorf("No SelectedRoot on conversation - this is a bug")
|
||||||
|
}
|
||||||
|
return s.PathToLeaf(c.SelectedRoot)
|
||||||
}
|
}
|
55
pkg/conversation/tools.go
Normal file
55
pkg/conversation/tools.go
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
package conversation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ApplySystemPrompt updates the contents of an existing system Message if it
|
||||||
|
// exists, or returns a new slice with the system Message prepended.
|
||||||
|
func ApplySystemPrompt(m []Message, system string, force bool) []Message {
|
||||||
|
if len(m) > 0 && m[0].Role == api.MessageRoleSystem {
|
||||||
|
if force {
|
||||||
|
m[0].Content = system
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
} else {
|
||||||
|
return append([]Message{{
|
||||||
|
Role: api.MessageRoleSystem,
|
||||||
|
Content: system,
|
||||||
|
}}, m...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func MessageToAPI(m Message) api.Message {
|
||||||
|
return api.Message{
|
||||||
|
Role: m.Role,
|
||||||
|
Content: m.Content,
|
||||||
|
ToolCalls: m.ToolCalls,
|
||||||
|
ToolResults: m.ToolResults,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func MessagesToAPI(messages []Message) []api.Message {
|
||||||
|
ret := make([]api.Message, 0, len(messages))
|
||||||
|
for _, m := range messages {
|
||||||
|
ret = append(ret, MessageToAPI(m))
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func MessageFromAPI(m api.Message) Message {
|
||||||
|
return Message{
|
||||||
|
Role: m.Role,
|
||||||
|
Content: m.Content,
|
||||||
|
ToolCalls: m.ToolCalls,
|
||||||
|
ToolResults: m.ToolResults,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func MessagesFromAPI(messages []api.Message) []Message {
|
||||||
|
ret := make([]Message, 0, len(messages))
|
||||||
|
for _, m := range messages {
|
||||||
|
ret = append(ret, MessageFromAPI(m))
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
@ -12,11 +12,12 @@ import (
|
|||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/agents"
|
"git.mlow.ca/mlow/lmcli/pkg/agents"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/anthropic"
|
"git.mlow.ca/mlow/lmcli/pkg/provider/anthropic"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/google"
|
"git.mlow.ca/mlow/lmcli/pkg/provider/google"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama"
|
"git.mlow.ca/mlow/lmcli/pkg/provider/ollama"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/openai"
|
"git.mlow.ca/mlow/lmcli/pkg/provider/openai"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
|
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
|
||||||
"gorm.io/driver/sqlite"
|
"gorm.io/driver/sqlite"
|
||||||
@ -33,7 +34,7 @@ type Agent struct {
|
|||||||
type Context struct {
|
type Context struct {
|
||||||
// high level app configuration, may be mutated at runtime
|
// high level app configuration, may be mutated at runtime
|
||||||
Config Config
|
Config Config
|
||||||
Store ConversationStore
|
Conversations conversation.Repo
|
||||||
Chroma *tty.ChromaHighlighter
|
Chroma *tty.ChromaHighlighter
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -44,7 +45,7 @@ func NewContext() (*Context, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
store, err := getConversationStore()
|
store, err := getConversationService()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -69,17 +70,16 @@ func createOrOpenAppend(path string) (*os.File, error) {
|
|||||||
return file, nil
|
return file, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getConversationStore() (ConversationStore, error) {
|
func getConversationService() (conversation.Repo, error) {
|
||||||
databaseFile := filepath.Join(dataDir(), "conversations.db")
|
databaseFile := filepath.Join(dataDir(), "conversations.db")
|
||||||
|
|
||||||
gormLogFile, err := createOrOpenAppend(filepath.Join(dataDir(), "database.log"))
|
gormLogFile, err := createOrOpenAppend(filepath.Join(dataDir(), "database.log"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Could not open database log file: %v", err)
|
return nil, fmt.Errorf("Could not open database log file: %v", err)
|
||||||
}
|
}
|
||||||
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{
|
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{
|
||||||
Logger: logger.New(log.New(gormLogFile, "", log.LstdFlags), logger.Config{
|
Logger: logger.New(log.New(gormLogFile, "\n", log.LstdFlags), logger.Config{
|
||||||
SlowThreshold: 200 * time.Millisecond,
|
SlowThreshold: 200 * time.Millisecond,
|
||||||
LogLevel: logger.Warn,
|
LogLevel: logger.Info,
|
||||||
IgnoreRecordNotFoundError: false,
|
IgnoreRecordNotFoundError: false,
|
||||||
Colorful: true,
|
Colorful: true,
|
||||||
}),
|
}),
|
||||||
@ -87,11 +87,11 @@ func getConversationStore() (ConversationStore, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
|
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
|
||||||
}
|
}
|
||||||
store, err := NewSQLStore(db)
|
repo, err := conversation.NewRepo(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return store, nil
|
return repo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) GetModels() (models []string) {
|
func (c *Context) GetModels() (models []string) {
|
||||||
|
@ -11,7 +11,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
const ANTHROPIC_VERSION = "2023-06-01"
|
const ANTHROPIC_VERSION = "2023-06-01"
|
||||||
@ -439,15 +439,9 @@ func convertResponseToMessage(resp ChatCompletionResponse) (*api.Message, error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
message := &api.Message{
|
|
||||||
Role: api.MessageRoleAssistant,
|
|
||||||
Content: content.String(),
|
|
||||||
ToolCalls: toolCalls,
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(toolCalls) > 0 {
|
if len(toolCalls) > 0 {
|
||||||
message.Role = api.MessageRoleToolCall
|
return api.NewMessageWithToolCalls(content.String(), toolCalls), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return message, nil
|
return api.NewMessageWithAssistant(content.String()), nil
|
||||||
}
|
}
|
@ -11,7 +11,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
@ -337,17 +337,10 @@ func (c *Client) CreateChatCompletion(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(toolCalls) > 0 {
|
if len(toolCalls) > 0 {
|
||||||
return &api.Message{
|
return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil
|
||||||
Role: api.MessageRoleToolCall,
|
|
||||||
Content: content,
|
|
||||||
ToolCalls: convertToolCallToAPI(toolCalls),
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &api.Message{
|
return api.NewMessageWithAssistant(content), nil
|
||||||
Role: api.MessageRoleAssistant,
|
|
||||||
Content: content,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) CreateChatCompletionStream(
|
func (c *Client) CreateChatCompletionStream(
|
||||||
@ -435,17 +428,9 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there are function calls, handle them and recurse
|
|
||||||
if len(toolCalls) > 0 {
|
if len(toolCalls) > 0 {
|
||||||
return &api.Message{
|
return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil
|
||||||
Role: api.MessageRoleToolCall,
|
|
||||||
Content: content.String(),
|
|
||||||
ToolCalls: convertToolCallToAPI(toolCalls),
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &api.Message{
|
return api.NewMessageWithAssistant(content.String()), nil
|
||||||
Role: api.MessageRoleAssistant,
|
|
||||||
Content: content.String(),
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
@ -11,7 +11,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OllamaClient struct {
|
type OllamaClient struct {
|
||||||
@ -115,10 +115,7 @@ func (c *OllamaClient) CreateChatCompletion(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &api.Message{
|
return api.NewMessageWithAssistant(completionResp.Message.Content), nil
|
||||||
Role: api.MessageRoleAssistant,
|
|
||||||
Content: completionResp.Message.Content,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *OllamaClient) CreateChatCompletionStream(
|
func (c *OllamaClient) CreateChatCompletionStream(
|
||||||
@ -182,8 +179,5 @@ func (c *OllamaClient) CreateChatCompletionStream(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &api.Message{
|
return api.NewMessageWithAssistant(content.String()), nil
|
||||||
Role: api.MessageRoleAssistant,
|
|
||||||
Content: content.String(),
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
@ -11,7 +11,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OpenAIClient struct {
|
type OpenAIClient struct {
|
||||||
@ -253,17 +253,10 @@ func (c *OpenAIClient) CreateChatCompletion(
|
|||||||
|
|
||||||
toolCalls := choice.Message.ToolCalls
|
toolCalls := choice.Message.ToolCalls
|
||||||
if len(toolCalls) > 0 {
|
if len(toolCalls) > 0 {
|
||||||
return &api.Message{
|
return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil
|
||||||
Role: api.MessageRoleToolCall,
|
|
||||||
Content: content,
|
|
||||||
ToolCalls: convertToolCallToAPI(toolCalls),
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &api.Message{
|
return api.NewMessageWithAssistant(content), nil
|
||||||
Role: api.MessageRoleAssistant,
|
|
||||||
Content: content,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *OpenAIClient) CreateChatCompletionStream(
|
func (c *OpenAIClient) CreateChatCompletionStream(
|
||||||
@ -343,15 +336,8 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(toolCalls) > 0 {
|
if len(toolCalls) > 0 {
|
||||||
return &api.Message{
|
return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil
|
||||||
Role: api.MessageRoleToolCall,
|
|
||||||
Content: content.String(),
|
|
||||||
ToolCalls: convertToolCallToAPI(toolCalls),
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &api.Message{
|
return api.NewMessageWithAssistant(content.String()), nil
|
||||||
Role: api.MessageRoleAssistant,
|
|
||||||
Content: content.String(),
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
@ -6,30 +6,30 @@ import (
|
|||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/agents"
|
"git.mlow.ca/mlow/lmcli/pkg/agents"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
||||||
"github.com/charmbracelet/lipgloss"
|
"github.com/charmbracelet/lipgloss"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LoadedConversation struct {
|
type LoadedConversation struct {
|
||||||
Conv api.Conversation
|
Conv conversation.Conversation
|
||||||
LastReply api.Message
|
LastReply conversation.Message
|
||||||
}
|
}
|
||||||
|
|
||||||
type AppModel struct {
|
type AppModel struct {
|
||||||
Ctx *lmcli.Context
|
Ctx *lmcli.Context
|
||||||
Conversations []LoadedConversation
|
Conversations []LoadedConversation
|
||||||
Conversation *api.Conversation
|
Conversation *conversation.Conversation
|
||||||
RootMessages []api.Message
|
Messages []conversation.Message
|
||||||
Messages []api.Message
|
|
||||||
Model string
|
Model string
|
||||||
ProviderName string
|
ProviderName string
|
||||||
Provider provider.ChatCompletionProvider
|
Provider provider.ChatCompletionProvider
|
||||||
Agent *lmcli.Agent
|
Agent *lmcli.Agent
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAppModel(ctx *lmcli.Context, initialConversation *api.Conversation) *AppModel {
|
func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversation) *AppModel {
|
||||||
app := &AppModel{
|
app := &AppModel{
|
||||||
Ctx: ctx,
|
Ctx: ctx,
|
||||||
Conversation: initialConversation,
|
Conversation: initialConversation,
|
||||||
@ -67,8 +67,7 @@ const (
|
|||||||
|
|
||||||
func (m *AppModel) ClearConversation() {
|
func (m *AppModel) ClearConversation() {
|
||||||
m.Conversation = nil
|
m.Conversation = nil
|
||||||
m.Messages = []api.Message{}
|
m.Messages = []conversation.Message{}
|
||||||
m.RootMessages = []api.Message{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *AppModel) ApplySystemPrompt() {
|
func (m *AppModel) ApplySystemPrompt() {
|
||||||
@ -81,7 +80,7 @@ func (m *AppModel) ApplySystemPrompt() {
|
|||||||
system = m.Ctx.DefaultSystemPrompt()
|
system = m.Ctx.DefaultSystemPrompt()
|
||||||
}
|
}
|
||||||
if system != "" {
|
if system != "" {
|
||||||
m.Messages = api.ApplySystemPrompt(m.Messages, system, false)
|
m.Messages = conversation.ApplySystemPrompt(m.Messages, system, false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -91,7 +90,7 @@ func (m *AppModel) NewConversation() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *AppModel) LoadConversations() (error, []LoadedConversation) {
|
func (m *AppModel) LoadConversations() (error, []LoadedConversation) {
|
||||||
messages, err := m.Ctx.Store.LatestConversationMessages()
|
messages, err := m.Ctx.Conversations.LatestConversationMessages()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Could not load conversations: %v", err), nil
|
return fmt.Errorf("Could not load conversations: %v", err), nil
|
||||||
}
|
}
|
||||||
@ -106,42 +105,34 @@ func (m *AppModel) LoadConversations() (error, []LoadedConversation) {
|
|||||||
return nil, conversations
|
return nil, conversations
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AppModel) LoadConversationRootMessages() ([]api.Message, error) {
|
func (a *AppModel) LoadConversationMessages() ([]conversation.Message, error) {
|
||||||
messages, err := a.Ctx.Store.RootMessages(a.Conversation.ID)
|
messages, err := a.Ctx.Conversations.PathToLeaf(a.Conversation.SelectedRoot)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("Could not load conversation root messages: %v %v", a.Conversation.SelectedRoot, err)
|
|
||||||
}
|
|
||||||
return messages, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AppModel) LoadConversationMessages() ([]api.Message, error) {
|
|
||||||
messages, err := a.Ctx.Store.PathToLeaf(a.Conversation.SelectedRoot)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Could not load conversation messages: %v %v", a.Conversation.SelectedRoot, err)
|
return nil, fmt.Errorf("Could not load conversation messages: %v %v", a.Conversation.SelectedRoot, err)
|
||||||
}
|
}
|
||||||
return messages, nil
|
return messages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AppModel) GenerateConversationTitle(messages []api.Message) (string, error) {
|
func (a *AppModel) GenerateConversationTitle(messages []conversation.Message) (string, error) {
|
||||||
return cmdutil.GenerateTitle(a.Ctx, messages)
|
return cmdutil.GenerateTitle(a.Ctx, messages)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AppModel) UpdateConversationTitle(conversation *api.Conversation) error {
|
func (a *AppModel) UpdateConversationTitle(conversation *conversation.Conversation) error {
|
||||||
return a.Ctx.Store.UpdateConversation(conversation)
|
return a.Ctx.Conversations.UpdateConversation(conversation)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AppModel) CloneMessage(message api.Message, selected bool) (*api.Message, error) {
|
func (a *AppModel) CloneMessage(message conversation.Message, selected bool) (*conversation.Message, error) {
|
||||||
msg, _, err := a.Ctx.Store.CloneBranch(message)
|
msg, _, err := a.Ctx.Conversations.CloneBranch(message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Could not clone message: %v", err)
|
return nil, fmt.Errorf("Could not clone message: %v", err)
|
||||||
}
|
}
|
||||||
if selected {
|
if selected {
|
||||||
if msg.Parent == nil {
|
if msg.Parent == nil {
|
||||||
msg.Conversation.SelectedRoot = msg
|
msg.Conversation.SelectedRoot = msg
|
||||||
err = a.Ctx.Store.UpdateConversation(msg.Conversation)
|
err = a.Ctx.Conversations.UpdateConversation(msg.Conversation)
|
||||||
} else {
|
} else {
|
||||||
msg.Parent.SelectedReply = msg
|
msg.Parent.SelectedReply = msg
|
||||||
err = a.Ctx.Store.UpdateMessage(msg.Parent)
|
err = a.Ctx.Conversations.UpdateMessage(msg.Parent)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Could not update selected message: %v", err)
|
return nil, fmt.Errorf("Could not update selected message: %v", err)
|
||||||
@ -150,11 +141,11 @@ func (a *AppModel) CloneMessage(message api.Message, selected bool) (*api.Messag
|
|||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AppModel) UpdateMessageContent(message *api.Message) error {
|
func (a *AppModel) UpdateMessageContent(message *conversation.Message) error {
|
||||||
return a.Ctx.Store.UpdateMessage(message)
|
return a.Ctx.Conversations.UpdateMessage(message)
|
||||||
}
|
}
|
||||||
|
|
||||||
func cycleSelectedMessage(selected *api.Message, choices []api.Message, dir MessageCycleDirection) (*api.Message, error) {
|
func cycleSelectedMessage(selected *conversation.Message, choices []conversation.Message, dir MessageCycleDirection) (*conversation.Message, error) {
|
||||||
currentIndex := -1
|
currentIndex := -1
|
||||||
for i, reply := range choices {
|
for i, reply := range choices {
|
||||||
if reply.ID == selected.ID {
|
if reply.ID == selected.ID {
|
||||||
@ -176,25 +167,25 @@ func cycleSelectedMessage(selected *api.Message, choices []api.Message, dir Mess
|
|||||||
return &choices[next], nil
|
return &choices[next], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AppModel) CycleSelectedRoot(conv *api.Conversation, rootMessages []api.Message, dir MessageCycleDirection) (*api.Message, error) {
|
func (a *AppModel) CycleSelectedRoot(conv *conversation.Conversation, dir MessageCycleDirection) (*conversation.Message, error) {
|
||||||
if len(rootMessages) < 2 {
|
if len(conv.RootMessages) < 2 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
nextRoot, err := cycleSelectedMessage(conv.SelectedRoot, rootMessages, dir)
|
nextRoot, err := cycleSelectedMessage(conv.SelectedRoot, conv.RootMessages, dir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
conv.SelectedRoot = nextRoot
|
conv.SelectedRoot = nextRoot
|
||||||
err = a.Ctx.Store.UpdateConversation(conv)
|
err = a.Ctx.Conversations.UpdateConversation(conv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Could not update conversation SelectedRoot: %v", err)
|
return nil, fmt.Errorf("Could not update conversation SelectedRoot: %v", err)
|
||||||
}
|
}
|
||||||
return nextRoot, nil
|
return nextRoot, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AppModel) CycleSelectedReply(message *api.Message, dir MessageCycleDirection) (*api.Message, error) {
|
func (a *AppModel) CycleSelectedReply(message *conversation.Message, dir MessageCycleDirection) (*conversation.Message, error) {
|
||||||
if len(message.Replies) < 2 {
|
if len(message.Replies) < 2 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@ -205,17 +196,17 @@ func (a *AppModel) CycleSelectedReply(message *api.Message, dir MessageCycleDire
|
|||||||
}
|
}
|
||||||
|
|
||||||
message.SelectedReply = nextReply
|
message.SelectedReply = nextReply
|
||||||
err = a.Ctx.Store.UpdateMessage(message)
|
err = a.Ctx.Conversations.UpdateMessage(message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Could not update message SelectedReply: %v", err)
|
return nil, fmt.Errorf("Could not update message SelectedReply: %v", err)
|
||||||
}
|
}
|
||||||
return nextReply, nil
|
return nextReply, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AppModel) PersistConversation(conversation *api.Conversation, messages []api.Message) (*api.Conversation, []api.Message, error) {
|
func (a *AppModel) PersistConversation(conversation *conversation.Conversation, messages []conversation.Message) (*conversation.Conversation, []conversation.Message, error) {
|
||||||
var err error
|
var err error
|
||||||
if conversation == nil || conversation.ID == 0 {
|
if conversation == nil || conversation.ID == 0 {
|
||||||
conversation, messages, err = a.Ctx.Store.StartConversation(messages...)
|
conversation, messages, err = a.Ctx.Conversations.StartConversation(messages...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("Could not start new conversation: %v", err)
|
return nil, nil, fmt.Errorf("Could not start new conversation: %v", err)
|
||||||
}
|
}
|
||||||
@ -224,12 +215,12 @@ func (a *AppModel) PersistConversation(conversation *api.Conversation, messages
|
|||||||
|
|
||||||
for i := range messages {
|
for i := range messages {
|
||||||
if messages[i].ID > 0 {
|
if messages[i].ID > 0 {
|
||||||
err := a.Ctx.Store.UpdateMessage(&messages[i])
|
err := a.Ctx.Conversations.UpdateMessage(&messages[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
} else if i > 0 {
|
} else if i > 0 {
|
||||||
saved, err := a.Ctx.Store.Reply(&messages[i-1], messages[i])
|
saved, err := a.Ctx.Conversations.Reply(&messages[i-1], messages[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
@ -251,10 +242,10 @@ func (a *AppModel) ExecuteToolCalls(toolCalls []api.ToolCall) ([]api.ToolResult,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *AppModel) Prompt(
|
func (a *AppModel) Prompt(
|
||||||
messages []api.Message,
|
messages []conversation.Message,
|
||||||
chatReplyChunks chan provider.Chunk,
|
chatReplyChunks chan provider.Chunk,
|
||||||
stopSignal chan struct{},
|
stopSignal chan struct{},
|
||||||
) (*api.Message, error) {
|
) (*conversation.Message, error) {
|
||||||
model, _, p, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName)
|
model, _, p, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -280,11 +271,14 @@ func (a *AppModel) Prompt(
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
msg, err := p.CreateChatCompletionStream(
|
msg, err := p.CreateChatCompletionStream(
|
||||||
ctx, params, messages, chatReplyChunks,
|
ctx, params, conversation.MessagesToAPI(messages), chatReplyChunks,
|
||||||
)
|
)
|
||||||
|
|
||||||
if msg != nil {
|
if msg != nil {
|
||||||
|
msg := conversation.MessageFromAPI(*msg)
|
||||||
msg.Metadata.GenerationProvider = &a.ProviderName
|
msg.Metadata.GenerationProvider = &a.ProviderName
|
||||||
msg.Metadata.GenerationModel = &a.Model
|
msg.Metadata.GenerationModel = &a.Model
|
||||||
|
return &msg, err
|
||||||
}
|
}
|
||||||
return msg, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,7 @@ package tui
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||||
@ -130,13 +130,13 @@ func (m *Model) View() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type LaunchOptions struct {
|
type LaunchOptions struct {
|
||||||
InitialConversation *api.Conversation
|
InitialConversation *conversation.Conversation
|
||||||
InitialView shared.View
|
InitialView shared.View
|
||||||
}
|
}
|
||||||
|
|
||||||
type LaunchOption func(*LaunchOptions)
|
type LaunchOption func(*LaunchOptions)
|
||||||
|
|
||||||
func WithInitialConversation(conv *api.Conversation) LaunchOption {
|
func WithInitialConversation(conv *conversation.Conversation) LaunchOption {
|
||||||
return func(opts *LaunchOptions) {
|
return func(opts *LaunchOptions) {
|
||||||
opts.InitialConversation = conv
|
opts.InitialConversation = conv
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
|
||||||
"github.com/charmbracelet/bubbles/cursor"
|
"github.com/charmbracelet/bubbles/cursor"
|
||||||
"github.com/charmbracelet/bubbles/spinner"
|
"github.com/charmbracelet/bubbles/spinner"
|
||||||
@ -20,14 +21,12 @@ type (
|
|||||||
msgConversationTitleGenerated string
|
msgConversationTitleGenerated string
|
||||||
// sent when the conversation has been persisted, triggers a reload of contents
|
// sent when the conversation has been persisted, triggers a reload of contents
|
||||||
msgConversationPersisted struct {
|
msgConversationPersisted struct {
|
||||||
isNew bool
|
conversation *conversation.Conversation
|
||||||
conversation *api.Conversation
|
messages []conversation.Message
|
||||||
messages []api.Message
|
|
||||||
}
|
}
|
||||||
// sent when a conversation's messages are laoded
|
// sent when a conversation's messages are laoded
|
||||||
msgConversationMessagesLoaded struct {
|
msgConversationMessagesLoaded struct {
|
||||||
messages []api.Message
|
messages []conversation.Message
|
||||||
rootMessages []api.Message
|
|
||||||
}
|
}
|
||||||
// a special case of common.MsgError that stops the response waiting animation
|
// a special case of common.MsgError that stops the response waiting animation
|
||||||
msgChatResponseError struct {
|
msgChatResponseError struct {
|
||||||
@ -36,19 +35,19 @@ type (
|
|||||||
// sent on each chunk received from LLM
|
// sent on each chunk received from LLM
|
||||||
msgChatResponseChunk provider.Chunk
|
msgChatResponseChunk provider.Chunk
|
||||||
// sent on each completed reply
|
// sent on each completed reply
|
||||||
msgChatResponse *api.Message
|
msgChatResponse *conversation.Message
|
||||||
// sent when the response is canceled
|
// sent when the response is canceled
|
||||||
msgChatResponseCanceled struct{}
|
msgChatResponseCanceled struct{}
|
||||||
// sent when results from a tool call are returned
|
// sent when results from a tool call are returned
|
||||||
msgToolResults []api.ToolResult
|
msgToolResults []api.ToolResult
|
||||||
// sent when the given message is made the new selected reply of its parent
|
// sent when the given message is made the new selected reply of its parent
|
||||||
msgSelectedReplyCycled *api.Message
|
msgSelectedReplyCycled *conversation.Message
|
||||||
// sent when the given message is made the new selected root of the current conversation
|
// sent when the given message is made the new selected root of the current conversation
|
||||||
msgSelectedRootCycled *api.Message
|
msgSelectedRootCycled *conversation.Message
|
||||||
// sent when a message's contents are updated and saved
|
// sent when a message's contents are updated and saved
|
||||||
msgMessageUpdated *api.Message
|
msgMessageUpdated *conversation.Message
|
||||||
// sent when a message is cloned, with the cloned message
|
// sent when a message is cloned, with the cloned message
|
||||||
msgMessageCloned *api.Message
|
msgMessageCloned *conversation.Message
|
||||||
)
|
)
|
||||||
|
|
||||||
type focusState int
|
type focusState int
|
||||||
@ -84,7 +83,7 @@ type Model struct {
|
|||||||
selectedMessage int
|
selectedMessage int
|
||||||
editorTarget editorTarget
|
editorTarget editorTarget
|
||||||
stopSignal chan struct{}
|
stopSignal chan struct{}
|
||||||
replyChan chan api.Message
|
replyChan chan conversation.Message
|
||||||
chatReplyChunks chan provider.Chunk
|
chatReplyChunks chan provider.Chunk
|
||||||
persistence bool // whether we will save new messages in the conversation
|
persistence bool // whether we will save new messages in the conversation
|
||||||
|
|
||||||
@ -137,7 +136,7 @@ func Chat(app *model.AppModel) *Model {
|
|||||||
persistence: true,
|
persistence: true,
|
||||||
|
|
||||||
stopSignal: make(chan struct{}),
|
stopSignal: make(chan struct{}),
|
||||||
replyChan: make(chan api.Message),
|
replyChan: make(chan conversation.Message),
|
||||||
chatReplyChunks: make(chan provider.Chunk),
|
chatReplyChunks: make(chan provider.Chunk),
|
||||||
|
|
||||||
wrap: true,
|
wrap: true,
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
@ -21,13 +22,7 @@ func (m *Model) loadConversationMessages() tea.Cmd {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return shared.AsMsgError(err)
|
return shared.AsMsgError(err)
|
||||||
}
|
}
|
||||||
rootMessages, err := m.App.LoadConversationRootMessages()
|
return msgConversationMessagesLoaded{messages}
|
||||||
if err != nil {
|
|
||||||
return shared.AsMsgError(err)
|
|
||||||
}
|
|
||||||
return msgConversationMessagesLoaded{
|
|
||||||
messages, rootMessages,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -41,7 +36,7 @@ func (m *Model) generateConversationTitle() tea.Cmd {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) updateConversationTitle(conversation *api.Conversation) tea.Cmd {
|
func (m *Model) updateConversationTitle(conversation *conversation.Conversation) tea.Cmd {
|
||||||
return func() tea.Msg {
|
return func() tea.Msg {
|
||||||
err := m.App.UpdateConversationTitle(conversation)
|
err := m.App.UpdateConversationTitle(conversation)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -51,7 +46,7 @@ func (m *Model) updateConversationTitle(conversation *api.Conversation) tea.Cmd
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) cloneMessage(message api.Message, selected bool) tea.Cmd {
|
func (m *Model) cloneMessage(message conversation.Message, selected bool) tea.Cmd {
|
||||||
return func() tea.Msg {
|
return func() tea.Msg {
|
||||||
msg, err := m.App.CloneMessage(message, selected)
|
msg, err := m.App.CloneMessage(message, selected)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -61,7 +56,7 @@ func (m *Model) cloneMessage(message api.Message, selected bool) tea.Cmd {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) updateMessageContent(message *api.Message) tea.Cmd {
|
func (m *Model) updateMessageContent(message *conversation.Message) tea.Cmd {
|
||||||
return func() tea.Msg {
|
return func() tea.Msg {
|
||||||
err := m.App.UpdateMessageContent(message)
|
err := m.App.UpdateMessageContent(message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -71,14 +66,13 @@ func (m *Model) updateMessageContent(message *api.Message) tea.Cmd {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) cycleSelectedRoot(conv *api.Conversation, dir model.MessageCycleDirection) tea.Cmd {
|
func (m *Model) cycleSelectedRoot(conv *conversation.Conversation, dir model.MessageCycleDirection) tea.Cmd {
|
||||||
if len(m.App.RootMessages) < 2 {
|
if len(conv.RootMessages) < 2 {
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return func() tea.Msg {
|
return func() tea.Msg {
|
||||||
nextRoot, err := m.App.CycleSelectedRoot(conv, m.App.RootMessages, dir)
|
nextRoot, err := m.App.CycleSelectedRoot(conv, dir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return shared.WrapError(err)
|
return shared.WrapError(err)
|
||||||
}
|
}
|
||||||
@ -86,7 +80,7 @@ func (m *Model) cycleSelectedRoot(conv *api.Conversation, dir model.MessageCycle
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) cycleSelectedReply(message *api.Message, dir model.MessageCycleDirection) tea.Cmd {
|
func (m *Model) cycleSelectedReply(message *conversation.Message, dir model.MessageCycleDirection) tea.Cmd {
|
||||||
if len(message.Replies) < 2 {
|
if len(message.Replies) < 2 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -106,7 +100,7 @@ func (m *Model) persistConversation() tea.Cmd {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return shared.AsMsgError(err)
|
return shared.AsMsgError(err)
|
||||||
}
|
}
|
||||||
return msgConversationPersisted{conversation.ID == 0, conversation, messages}
|
return msgConversationPersisted{conversation, messages}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||||
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
||||||
@ -70,12 +71,12 @@ func (m *Model) handleInput(msg tea.KeyMsg) tea.Cmd {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) scrollSelection(dir int) {
|
func (m *Model) scrollSelection(dir int) {
|
||||||
if m.selectedMessage + dir < 0 || m.selectedMessage + dir >= len(m.App.Messages) {
|
if m.selectedMessage+dir < 0 || m.selectedMessage+dir >= len(m.App.Messages) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
newIdx := m.selectedMessage
|
newIdx := m.selectedMessage
|
||||||
for i := newIdx + dir; i >= 0 && i < len(m.App.Messages); i += dir{
|
for i := newIdx + dir; i >= 0 && i < len(m.App.Messages); i += dir {
|
||||||
if !m.showDetails && m.App.Messages[i].Role.IsSystem() {
|
if !m.showDetails && m.App.Messages[i].Role.IsSystem() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -175,7 +176,7 @@ func (m *Model) handleInputKey(msg tea.KeyMsg) tea.Cmd {
|
|||||||
return shared.WrapError(fmt.Errorf("Can't reply to a user message"))
|
return shared.WrapError(fmt.Errorf("Can't reply to a user message"))
|
||||||
}
|
}
|
||||||
|
|
||||||
m.addMessage(api.Message{
|
m.addMessage(conversation.Message{
|
||||||
Role: api.MessageRoleUser,
|
Role: api.MessageRoleUser,
|
||||||
Content: input,
|
Content: input,
|
||||||
})
|
})
|
||||||
|
@ -5,13 +5,14 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||||
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
||||||
"github.com/charmbracelet/bubbles/cursor"
|
"github.com/charmbracelet/bubbles/cursor"
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (m *Model) setMessage(i int, msg api.Message) {
|
func (m *Model) setMessage(i int, msg conversation.Message) {
|
||||||
if i >= len(m.App.Messages) {
|
if i >= len(m.App.Messages) {
|
||||||
panic("i out of range")
|
panic("i out of range")
|
||||||
}
|
}
|
||||||
@ -19,7 +20,7 @@ func (m *Model) setMessage(i int, msg api.Message) {
|
|||||||
m.messageCache[i] = m.renderMessage(i)
|
m.messageCache[i] = m.renderMessage(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) addMessage(msg api.Message) {
|
func (m *Model) addMessage(msg conversation.Message) {
|
||||||
m.App.Messages = append(m.App.Messages, msg)
|
m.App.Messages = append(m.App.Messages, msg)
|
||||||
m.messageCache = append(m.messageCache, m.renderMessage(len(m.App.Messages)-1))
|
m.messageCache = append(m.messageCache, m.renderMessage(len(m.App.Messages)-1))
|
||||||
}
|
}
|
||||||
@ -95,7 +96,6 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
case msgConversationMessagesLoaded:
|
case msgConversationMessagesLoaded:
|
||||||
m.App.RootMessages = msg.rootMessages
|
|
||||||
m.App.Messages = msg.messages
|
m.App.Messages = msg.messages
|
||||||
if m.selectedMessage == -1 {
|
if m.selectedMessage == -1 {
|
||||||
m.selectedMessage = len(msg.messages) - 1
|
m.selectedMessage = len(msg.messages) - 1
|
||||||
@ -117,7 +117,7 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
|
|||||||
m.setMessageContents(last, m.App.Messages[last].Content+msg.Content)
|
m.setMessageContents(last, m.App.Messages[last].Content+msg.Content)
|
||||||
} else {
|
} else {
|
||||||
// use chunk in a new message
|
// use chunk in a new message
|
||||||
m.addMessage(api.Message{
|
m.addMessage(conversation.Message{
|
||||||
Role: api.MessageRoleAssistant,
|
Role: api.MessageRoleAssistant,
|
||||||
Content: msg.Content,
|
Content: msg.Content,
|
||||||
})
|
})
|
||||||
@ -133,7 +133,7 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
|
|||||||
case msgChatResponse:
|
case msgChatResponse:
|
||||||
m.state = idle
|
m.state = idle
|
||||||
|
|
||||||
reply := (*api.Message)(msg)
|
reply := (*conversation.Message)(msg)
|
||||||
reply.Content = strings.TrimSpace(reply.Content)
|
reply.Content = strings.TrimSpace(reply.Content)
|
||||||
|
|
||||||
last := len(m.App.Messages) - 1
|
last := len(m.App.Messages) - 1
|
||||||
@ -181,9 +181,9 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
|
|||||||
panic("Previous message not a tool call, unexpected")
|
panic("Previous message not a tool call, unexpected")
|
||||||
}
|
}
|
||||||
|
|
||||||
m.addMessage(api.Message{
|
m.addMessage(conversation.Message{
|
||||||
Role: api.MessageRoleToolResult,
|
Role: api.MessageRoleToolResult,
|
||||||
ToolResults: api.ToolResults(msg),
|
ToolResults: conversation.ToolResults(msg),
|
||||||
})
|
})
|
||||||
|
|
||||||
if m.persistence {
|
if m.persistence {
|
||||||
@ -207,15 +207,11 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
|
|||||||
case msgConversationPersisted:
|
case msgConversationPersisted:
|
||||||
m.App.Conversation = msg.conversation
|
m.App.Conversation = msg.conversation
|
||||||
m.App.Messages = msg.messages
|
m.App.Messages = msg.messages
|
||||||
if msg.isNew {
|
|
||||||
m.App.RootMessages = []api.Message{m.App.Messages[0]}
|
|
||||||
}
|
|
||||||
m.rebuildMessageCache()
|
m.rebuildMessageCache()
|
||||||
m.updateContent()
|
m.updateContent()
|
||||||
case msgMessageCloned:
|
case msgMessageCloned:
|
||||||
if msg.Parent == nil {
|
if msg.Parent == nil {
|
||||||
m.App.Conversation = msg.Conversation
|
m.App.Conversation = msg.Conversation
|
||||||
m.App.RootMessages = append(m.App.RootMessages, *msg)
|
|
||||||
}
|
}
|
||||||
cmds = append(cmds, m.loadConversationMessages())
|
cmds = append(cmds, m.loadConversationMessages())
|
||||||
case msgSelectedRootCycled, msgSelectedReplyCycled, msgMessageUpdated:
|
case msgSelectedRootCycled, msgSelectedReplyCycled, msgMessageUpdated:
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui/styles"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/styles"
|
||||||
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
||||||
"github.com/charmbracelet/lipgloss"
|
"github.com/charmbracelet/lipgloss"
|
||||||
@ -44,7 +45,7 @@ var (
|
|||||||
footerStyle = lipgloss.NewStyle().Padding(0, 1)
|
footerStyle = lipgloss.NewStyle().Padding(0, 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
func (m *Model) renderMessageHeading(i int, message *api.Message) string {
|
func (m *Model) renderMessageHeading(i int, message *conversation.Message) string {
|
||||||
friendly := message.Role.FriendlyRole()
|
friendly := message.Role.FriendlyRole()
|
||||||
style := systemStyle
|
style := systemStyle
|
||||||
|
|
||||||
@ -70,15 +71,15 @@ func (m *Model) renderMessageHeading(i int, message *api.Message) string {
|
|||||||
prefix = " "
|
prefix = " "
|
||||||
}
|
}
|
||||||
|
|
||||||
if i == 0 && len(m.App.RootMessages) > 1 && m.App.Conversation.SelectedRootID != nil {
|
if i == 0 && len(m.App.Conversation.RootMessages) > 1 && m.App.Conversation.SelectedRootID != nil {
|
||||||
selectedRootIndex := 0
|
selectedRootIndex := 0
|
||||||
for j, reply := range m.App.RootMessages {
|
for j, reply := range m.App.Conversation.RootMessages {
|
||||||
if reply.ID == *m.App.Conversation.SelectedRootID {
|
if reply.ID == *m.App.Conversation.SelectedRootID {
|
||||||
selectedRootIndex = j
|
selectedRootIndex = j
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
suffix += faintStyle.Render(fmt.Sprintf(" <%d/%d>", selectedRootIndex+1, len(m.App.RootMessages)))
|
suffix += faintStyle.Render(fmt.Sprintf(" <%d/%d>", selectedRootIndex+1, len(m.App.Conversation.RootMessages)))
|
||||||
}
|
}
|
||||||
if i > 0 && len(m.App.Messages[i-1].Replies) > 1 {
|
if i > 0 && len(m.App.Messages[i-1].Replies) > 1 {
|
||||||
// Find the selected reply index
|
// Find the selected reply index
|
||||||
@ -230,9 +231,9 @@ func (m *Model) conversationMessagesView() string {
|
|||||||
|
|
||||||
// Render a placeholder for the incoming assistant reply
|
// Render a placeholder for the incoming assistant reply
|
||||||
if m.state == pendingResponse && m.App.Messages[len(m.App.Messages)-1].Role != api.MessageRoleAssistant {
|
if m.state == pendingResponse && m.App.Messages[len(m.App.Messages)-1].Role != api.MessageRoleAssistant {
|
||||||
heading := m.renderMessageHeading(-1, &api.Message{
|
heading := m.renderMessageHeading(-1, &conversation.Message{
|
||||||
Role: api.MessageRoleAssistant,
|
Role: api.MessageRoleAssistant,
|
||||||
Metadata: api.MessageMeta{
|
Metadata: conversation.MessageMeta{
|
||||||
GenerationModel: &m.App.Model,
|
GenerationModel: &m.App.Model,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
@ -5,7 +5,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui/bubbles"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/bubbles"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||||
@ -21,7 +21,7 @@ type (
|
|||||||
// sent when conversation list is loaded
|
// sent when conversation list is loaded
|
||||||
msgConversationsLoaded ([]model.LoadedConversation)
|
msgConversationsLoaded ([]model.LoadedConversation)
|
||||||
// sent when a conversation is selected
|
// sent when a conversation is selected
|
||||||
msgConversationSelected api.Conversation
|
msgConversationSelected conversation.Conversation
|
||||||
// sent when a conversation is deleted
|
// sent when a conversation is deleted
|
||||||
msgConversationDeleted struct{}
|
msgConversationDeleted struct{}
|
||||||
)
|
)
|
||||||
@ -154,7 +154,7 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
|
|||||||
case bubbles.MsgConfirmPromptAnswered:
|
case bubbles.MsgConfirmPromptAnswered:
|
||||||
m.confirmPrompt.Blur()
|
m.confirmPrompt.Blur()
|
||||||
if msg.Value {
|
if msg.Value {
|
||||||
conv, ok := msg.Payload.(api.Conversation)
|
conv, ok := msg.Payload.(conversation.Conversation)
|
||||||
if ok {
|
if ok {
|
||||||
cmds = append(cmds, m.deleteConversation(conv))
|
cmds = append(cmds, m.deleteConversation(conv))
|
||||||
}
|
}
|
||||||
@ -188,9 +188,9 @@ func (m *Model) loadConversations() tea.Cmd {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) deleteConversation(conv api.Conversation) tea.Cmd {
|
func (m *Model) deleteConversation(conv conversation.Conversation) tea.Cmd {
|
||||||
return func() tea.Msg {
|
return func() tea.Msg {
|
||||||
err := m.App.Ctx.Store.DeleteConversation(&conv)
|
err := m.App.Ctx.Conversations.DeleteConversation(&conv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return shared.AsMsgError(fmt.Errorf("Could not delete conversation: %v", err))
|
return shared.AsMsgError(fmt.Errorf("Could not delete conversation: %v", err))
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user