Package restructure and API changes, several fixes
- More emphasis on `api` package. It now holds database model structs from `lmcli/models` (which is now gone) as well as the tool spec, call, and result types. `tools.Tool` is now `api.ToolSpec`. `api.ChatCompletionClient` was renamed to `api.ChatCompletionProvider`. - Change ChatCompletion interface and implementations to no longer do automatic tool call recursion - they simply return a ToolCall message which the caller can decide what to do with (e.g. prompt for user confirmation before executing) - `api.ChatCompletionProvider` functions have had their ReplyCallback parameter removed, as now they only return a single reply. - Added a top-level `agent` package, moved the current built-in tools implementations under `agent/toolbox`. `tools.ExecuteToolCalls` is now `agent.ExecuteToolCalls`. - Fixed request context handling in openai, google, ollama (use `NewRequestWithContext`), cleaned up request cancellation in TUI - Fix tool call tui persistence bug (we were skipping message with empty content) - Now handle tool calling from TUI layer TODO: - Prompt users before executing tool calls - Automatically send tool results to the model (or make this toggleable)
This commit is contained in:
@@ -6,13 +6,12 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/agent"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/anthropic"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/google"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/openai"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
|
||||
"gorm.io/driver/sqlite"
|
||||
@@ -24,7 +23,7 @@ type Context struct {
|
||||
Store ConversationStore
|
||||
|
||||
Chroma *tty.ChromaHighlighter
|
||||
EnabledTools []model.Tool
|
||||
EnabledTools []api.ToolSpec
|
||||
|
||||
SystemPromptFile string
|
||||
}
|
||||
@@ -50,9 +49,9 @@ func NewContext() (*Context, error) {
|
||||
|
||||
chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)
|
||||
|
||||
var enabledTools []model.Tool
|
||||
var enabledTools []api.ToolSpec
|
||||
for _, toolName := range config.Tools.EnabledTools {
|
||||
tool, ok := tools.AvailableTools[toolName]
|
||||
tool, ok := agent.AvailableTools[toolName]
|
||||
if ok {
|
||||
enabledTools = append(enabledTools, tool)
|
||||
}
|
||||
@@ -79,7 +78,7 @@ func (c *Context) GetModels() (models []string) {
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionClient, error) {
|
||||
func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProvider, error) {
|
||||
parts := strings.Split(model, "@")
|
||||
|
||||
var provider string
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
)
|
||||
|
||||
type MessageRole string
|
||||
|
||||
const (
|
||||
MessageRoleSystem MessageRole = "system"
|
||||
MessageRoleUser MessageRole = "user"
|
||||
MessageRoleAssistant MessageRole = "assistant"
|
||||
MessageRoleToolCall MessageRole = "tool_call"
|
||||
MessageRoleToolResult MessageRole = "tool_result"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ConversationID *uint `gorm:"index"`
|
||||
Conversation *Conversation `gorm:"foreignKey:ConversationID"`
|
||||
Content string
|
||||
Role MessageRole
|
||||
CreatedAt time.Time
|
||||
ToolCalls ToolCalls // a json array of tool calls (from the model)
|
||||
ToolResults ToolResults // a json array of tool results
|
||||
ParentID *uint
|
||||
Parent *Message `gorm:"foreignKey:ParentID"`
|
||||
Replies []Message `gorm:"foreignKey:ParentID"`
|
||||
|
||||
SelectedReplyID *uint
|
||||
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
|
||||
}
|
||||
|
||||
type Conversation struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ShortName sql.NullString
|
||||
Title string
|
||||
SelectedRootID *uint
|
||||
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
|
||||
}
|
||||
|
||||
type RequestParameters struct {
|
||||
Model string
|
||||
|
||||
MaxTokens int
|
||||
Temperature float32
|
||||
TopP float32
|
||||
|
||||
ToolBag []Tool
|
||||
}
|
||||
|
||||
func (m *MessageRole) IsAssistant() bool {
|
||||
switch *m {
|
||||
case MessageRoleAssistant, MessageRoleToolCall:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// FriendlyRole returns a human friendly signifier for the message's role.
|
||||
func (m MessageRole) FriendlyRole() string {
|
||||
switch m {
|
||||
case MessageRoleUser:
|
||||
return "You"
|
||||
case MessageRoleSystem:
|
||||
return "System"
|
||||
case MessageRoleAssistant:
|
||||
return "Assistant"
|
||||
case MessageRoleToolCall:
|
||||
return "Tool Call"
|
||||
case MessageRoleToolResult:
|
||||
return "Tool Result"
|
||||
default:
|
||||
return string(m)
|
||||
}
|
||||
}
|
||||
@@ -1,98 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Tool struct {
|
||||
Name string
|
||||
Description string
|
||||
Parameters []ToolParameter
|
||||
Impl func(*Tool, map[string]interface{}) (string, error)
|
||||
}
|
||||
|
||||
type ToolParameter struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"` // "string", "integer", "boolean"
|
||||
Required bool `json:"required"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
ID string `json:"id" yaml:"-"`
|
||||
Name string `json:"name" yaml:"tool"`
|
||||
Parameters map[string]interface{} `json:"parameters" yaml:"parameters"`
|
||||
}
|
||||
|
||||
type ToolCalls []ToolCall
|
||||
|
||||
func (tc *ToolCalls) Scan(value any) (err error) {
|
||||
s := value.(string)
|
||||
if value == nil || s == "" {
|
||||
*tc = nil
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal([]byte(s), tc)
|
||||
return
|
||||
}
|
||||
|
||||
func (tc ToolCalls) Value() (driver.Value, error) {
|
||||
if len(tc) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
jsonBytes, err := json.Marshal(tc)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not marshal ToolCalls to JSON: %v\n", err)
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
||||
|
||||
type ToolResult struct {
|
||||
ToolCallID string `json:"toolCallID" yaml:"-"`
|
||||
ToolName string `json:"toolName,omitempty" yaml:"tool"`
|
||||
Result string `json:"result,omitempty" yaml:"result"`
|
||||
}
|
||||
|
||||
type ToolResults []ToolResult
|
||||
|
||||
func (tr *ToolResults) Scan(value any) (err error) {
|
||||
s := value.(string)
|
||||
if value == nil || s == "" {
|
||||
*tr = nil
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal([]byte(s), tr)
|
||||
return
|
||||
}
|
||||
|
||||
func (tr ToolResults) Value() (driver.Value, error) {
|
||||
if len(tr) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
jsonBytes, err := json.Marshal([]ToolResult(tr))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not marshal ToolResults to JSON: %v\n", err)
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
||||
|
||||
type CallResult struct {
|
||||
Message string `json:"message"`
|
||||
Result any `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
func (r CallResult) ToJson() (string, error) {
|
||||
if r.Message == "" {
|
||||
// When message not supplied, assume success
|
||||
r.Message = "success"
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not marshal CallResult to JSON: %v\n", err)
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
||||
@@ -8,32 +8,32 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
sqids "github.com/sqids/sqids-go"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ConversationStore interface {
|
||||
ConversationByShortName(shortName string) (*model.Conversation, error)
|
||||
ConversationByShortName(shortName string) (*api.Conversation, error)
|
||||
ConversationShortNameCompletions(search string) []string
|
||||
RootMessages(conversationID uint) ([]model.Message, error)
|
||||
LatestConversationMessages() ([]model.Message, error)
|
||||
RootMessages(conversationID uint) ([]api.Message, error)
|
||||
LatestConversationMessages() ([]api.Message, error)
|
||||
|
||||
StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error)
|
||||
UpdateConversation(conversation *model.Conversation) error
|
||||
DeleteConversation(conversation *model.Conversation) error
|
||||
CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error)
|
||||
StartConversation(messages ...api.Message) (*api.Conversation, []api.Message, error)
|
||||
UpdateConversation(conversation *api.Conversation) error
|
||||
DeleteConversation(conversation *api.Conversation) error
|
||||
CloneConversation(toClone api.Conversation) (*api.Conversation, uint, error)
|
||||
|
||||
MessageByID(messageID uint) (*model.Message, error)
|
||||
MessageReplies(messageID uint) ([]model.Message, error)
|
||||
MessageByID(messageID uint) (*api.Message, error)
|
||||
MessageReplies(messageID uint) ([]api.Message, error)
|
||||
|
||||
UpdateMessage(message *model.Message) error
|
||||
DeleteMessage(message *model.Message, prune bool) error
|
||||
CloneBranch(toClone model.Message) (*model.Message, uint, error)
|
||||
Reply(to *model.Message, messages ...model.Message) ([]model.Message, error)
|
||||
UpdateMessage(message *api.Message) error
|
||||
DeleteMessage(message *api.Message, prune bool) error
|
||||
CloneBranch(toClone api.Message) (*api.Message, uint, error)
|
||||
Reply(to *api.Message, messages ...api.Message) ([]api.Message, error)
|
||||
|
||||
PathToRoot(message *model.Message) ([]model.Message, error)
|
||||
PathToLeaf(message *model.Message) ([]model.Message, error)
|
||||
PathToRoot(message *api.Message) ([]api.Message, error)
|
||||
PathToLeaf(message *api.Message) ([]api.Message, error)
|
||||
}
|
||||
|
||||
type SQLStore struct {
|
||||
@@ -43,8 +43,8 @@ type SQLStore struct {
|
||||
|
||||
func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
|
||||
models := []any{
|
||||
&model.Conversation{},
|
||||
&model.Message{},
|
||||
&api.Conversation{},
|
||||
&api.Message{},
|
||||
}
|
||||
|
||||
for _, x := range models {
|
||||
@@ -58,9 +58,9 @@ func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
|
||||
return &SQLStore{db, _sqids}, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) createConversation() (*model.Conversation, error) {
|
||||
func (s *SQLStore) createConversation() (*api.Conversation, error) {
|
||||
// Create the new conversation
|
||||
c := &model.Conversation{}
|
||||
c := &api.Conversation{}
|
||||
err := s.db.Save(c).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -75,28 +75,28 @@ func (s *SQLStore) createConversation() (*model.Conversation, error) {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) UpdateConversation(c *model.Conversation) error {
|
||||
func (s *SQLStore) UpdateConversation(c *api.Conversation) error {
|
||||
if c == nil || c.ID == 0 {
|
||||
return fmt.Errorf("Conversation is nil or invalid (missing ID)")
|
||||
}
|
||||
return s.db.Updates(c).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) DeleteConversation(c *model.Conversation) error {
|
||||
func (s *SQLStore) DeleteConversation(c *api.Conversation) error {
|
||||
// Delete messages first
|
||||
err := s.db.Where("conversation_id = ?", c.ID).Delete(&model.Message{}).Error
|
||||
err := s.db.Where("conversation_id = ?", c.ID).Delete(&api.Message{}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.db.Delete(c).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) DeleteMessage(message *model.Message, prune bool) error {
|
||||
func (s *SQLStore) DeleteMessage(message *api.Message, prune bool) error {
|
||||
panic("Not yet implemented")
|
||||
//return s.db.Delete(&message).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) UpdateMessage(m *model.Message) error {
|
||||
func (s *SQLStore) UpdateMessage(m *api.Message) error {
|
||||
if m == nil || m.ID == 0 {
|
||||
return fmt.Errorf("Message is nil or invalid (missing ID)")
|
||||
}
|
||||
@@ -104,7 +104,7 @@ func (s *SQLStore) UpdateMessage(m *model.Message) error {
|
||||
}
|
||||
|
||||
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
|
||||
var conversations []model.Conversation
|
||||
var conversations []api.Conversation
|
||||
// ignore error for completions
|
||||
s.db.Find(&conversations)
|
||||
completions := make([]string, 0, len(conversations))
|
||||
@@ -116,17 +116,17 @@ func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
|
||||
return completions
|
||||
}
|
||||
|
||||
func (s *SQLStore) ConversationByShortName(shortName string) (*model.Conversation, error) {
|
||||
func (s *SQLStore) ConversationByShortName(shortName string) (*api.Conversation, error) {
|
||||
if shortName == "" {
|
||||
return nil, errors.New("shortName is empty")
|
||||
}
|
||||
var conversation model.Conversation
|
||||
var conversation api.Conversation
|
||||
err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error
|
||||
return &conversation, err
|
||||
}
|
||||
|
||||
func (s *SQLStore) RootMessages(conversationID uint) ([]model.Message, error) {
|
||||
var rootMessages []model.Message
|
||||
func (s *SQLStore) RootMessages(conversationID uint) ([]api.Message, error) {
|
||||
var rootMessages []api.Message
|
||||
err := s.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -134,20 +134,20 @@ func (s *SQLStore) RootMessages(conversationID uint) ([]model.Message, error) {
|
||||
return rootMessages, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) MessageByID(messageID uint) (*model.Message, error) {
|
||||
var message model.Message
|
||||
func (s *SQLStore) MessageByID(messageID uint) (*api.Message, error) {
|
||||
var message api.Message
|
||||
err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error
|
||||
return &message, err
|
||||
}
|
||||
|
||||
func (s *SQLStore) MessageReplies(messageID uint) ([]model.Message, error) {
|
||||
var replies []model.Message
|
||||
func (s *SQLStore) MessageReplies(messageID uint) ([]api.Message, error) {
|
||||
var replies []api.Message
|
||||
err := s.db.Where("parent_id = ?", messageID).Find(&replies).Error
|
||||
return replies, err
|
||||
}
|
||||
|
||||
// StartConversation starts a new conversation with the provided messages
|
||||
func (s *SQLStore) StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error) {
|
||||
func (s *SQLStore) StartConversation(messages ...api.Message) (*api.Conversation, []api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, nil, fmt.Errorf("Must provide at least 1 message")
|
||||
}
|
||||
@@ -178,13 +178,13 @@ func (s *SQLStore) StartConversation(messages ...model.Message) (*model.Conversa
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
messages = append([]model.Message{messages[0]}, newMessages...)
|
||||
messages = append([]api.Message{messages[0]}, newMessages...)
|
||||
}
|
||||
return conversation, messages, nil
|
||||
}
|
||||
|
||||
// CloneConversation clones the given conversation and all of its root meesages
|
||||
func (s *SQLStore) CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error) {
|
||||
func (s *SQLStore) CloneConversation(toClone api.Conversation) (*api.Conversation, uint, error) {
|
||||
rootMessages, err := s.RootMessages(toClone.ID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -226,8 +226,8 @@ func (s *SQLStore) CloneConversation(toClone model.Conversation) (*model.Convers
|
||||
}
|
||||
|
||||
// Reply to a message with a series of messages (each following the next)
|
||||
func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model.Message, error) {
|
||||
var savedMessages []model.Message
|
||||
func (s *SQLStore) Reply(to *api.Message, messages ...api.Message) ([]api.Message, error) {
|
||||
var savedMessages []api.Message
|
||||
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
currentParent := to
|
||||
@@ -262,7 +262,7 @@ func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model.
|
||||
// CloneBranch returns a deep clone of the given message and its replies, returning
|
||||
// a new message object. The new message will be attached to the same parent as
|
||||
// the messageToClone
|
||||
func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, uint, error) {
|
||||
func (s *SQLStore) CloneBranch(messageToClone api.Message) (*api.Message, uint, error) {
|
||||
newMessage := messageToClone
|
||||
newMessage.ID = 0
|
||||
newMessage.Replies = nil
|
||||
@@ -304,19 +304,19 @@ func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, ui
|
||||
return &newMessage, replyCount, nil
|
||||
}
|
||||
|
||||
func fetchMessages(db *gorm.DB) ([]model.Message, error) {
|
||||
var messages []model.Message
|
||||
func fetchMessages(db *gorm.DB) ([]api.Message, error) {
|
||||
var messages []api.Message
|
||||
if err := db.Preload("Conversation").Find(&messages).Error; err != nil {
|
||||
return nil, fmt.Errorf("Could not fetch messages: %v", err)
|
||||
}
|
||||
|
||||
messageMap := make(map[uint]model.Message)
|
||||
messageMap := make(map[uint]api.Message)
|
||||
for i, message := range messages {
|
||||
messageMap[messages[i].ID] = message
|
||||
}
|
||||
|
||||
// Create a map to store replies by their parent ID
|
||||
repliesMap := make(map[uint][]model.Message)
|
||||
repliesMap := make(map[uint][]api.Message)
|
||||
for i, message := range messages {
|
||||
if messages[i].ParentID != nil {
|
||||
repliesMap[*messages[i].ParentID] = append(repliesMap[*messages[i].ParentID], message)
|
||||
@@ -326,7 +326,7 @@ func fetchMessages(db *gorm.DB) ([]model.Message, error) {
|
||||
// Assign replies, parent, and selected reply to each message
|
||||
for i := range messages {
|
||||
if replies, exists := repliesMap[messages[i].ID]; exists {
|
||||
messages[i].Replies = make([]model.Message, len(replies))
|
||||
messages[i].Replies = make([]api.Message, len(replies))
|
||||
for j, m := range replies {
|
||||
messages[i].Replies[j] = m
|
||||
}
|
||||
@@ -345,21 +345,21 @@ func fetchMessages(db *gorm.DB) ([]model.Message, error) {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) buildPath(message *model.Message, getNext func(*model.Message) *uint) ([]model.Message, error) {
|
||||
var messages []model.Message
|
||||
func (s *SQLStore) buildPath(message *api.Message, getNext func(*api.Message) *uint) ([]api.Message, error) {
|
||||
var messages []api.Message
|
||||
messages, err := fetchMessages(s.db.Where("conversation_id = ?", message.ConversationID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create a map to store messages by their ID
|
||||
messageMap := make(map[uint]*model.Message)
|
||||
messageMap := make(map[uint]*api.Message)
|
||||
for i := range messages {
|
||||
messageMap[messages[i].ID] = &messages[i]
|
||||
}
|
||||
|
||||
// Build the path
|
||||
var path []model.Message
|
||||
var path []api.Message
|
||||
nextID := &message.ID
|
||||
|
||||
for {
|
||||
@@ -382,12 +382,12 @@ func (s *SQLStore) buildPath(message *model.Message, getNext func(*model.Message
|
||||
// PathToRoot traverses the provided message's Parent until reaching the tree
|
||||
// root and returns a slice of all messages traversed in chronological order
|
||||
// (starting with the root and ending with the message provided)
|
||||
func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {
|
||||
func (s *SQLStore) PathToRoot(message *api.Message) ([]api.Message, error) {
|
||||
if message == nil || message.ID <= 0 {
|
||||
return nil, fmt.Errorf("Message is nil or has invalid ID")
|
||||
}
|
||||
|
||||
path, err := s.buildPath(message, func(m *model.Message) *uint {
|
||||
path, err := s.buildPath(message, func(m *api.Message) *uint {
|
||||
return m.ParentID
|
||||
})
|
||||
if err != nil {
|
||||
@@ -401,24 +401,24 @@ func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {
|
||||
// PathToLeaf traverses the provided message's SelectedReply until reaching a
|
||||
// tree leaf and returns a slice of all messages traversed in chronological
|
||||
// order (starting with the message provided and ending with the leaf)
|
||||
func (s *SQLStore) PathToLeaf(message *model.Message) ([]model.Message, error) {
|
||||
func (s *SQLStore) PathToLeaf(message *api.Message) ([]api.Message, error) {
|
||||
if message == nil || message.ID <= 0 {
|
||||
return nil, fmt.Errorf("Message is nil or has invalid ID")
|
||||
}
|
||||
|
||||
return s.buildPath(message, func(m *model.Message) *uint {
|
||||
return s.buildPath(message, func(m *api.Message) *uint {
|
||||
return m.SelectedReplyID
|
||||
})
|
||||
}
|
||||
|
||||
func (s *SQLStore) LatestConversationMessages() ([]model.Message, error) {
|
||||
var latestMessages []model.Message
|
||||
func (s *SQLStore) LatestConversationMessages() ([]api.Message, error) {
|
||||
var latestMessages []api.Message
|
||||
|
||||
subQuery := s.db.Model(&model.Message{}).
|
||||
subQuery := s.db.Model(&api.Message{}).
|
||||
Select("MAX(created_at) as max_created_at, conversation_id").
|
||||
Group("conversation_id")
|
||||
|
||||
err := s.db.Model(&model.Message{}).
|
||||
err := s.db.Model(&api.Message{}).
|
||||
Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery).
|
||||
Group("messages.conversation_id").
|
||||
Order("created_at DESC").
|
||||
|
||||
@@ -1,142 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
||||
)
|
||||
|
||||
const TREE_DESCRIPTION = `Retrieve a tree-like view of a directory's contents.
|
||||
|
||||
Use these results for your own reference in completing your task, they do not need to be shown to the user.
|
||||
|
||||
Example result:
|
||||
{
|
||||
"message": "success",
|
||||
"result": ".
|
||||
├── a_directory/
|
||||
│ ├── file1.txt (100 bytes)
|
||||
│ └── file2.txt (200 bytes)
|
||||
├── a_file.txt (123 bytes)
|
||||
└── another_file.txt (456 bytes)"
|
||||
}
|
||||
`
|
||||
|
||||
var DirTreeTool = model.Tool{
|
||||
Name: "dir_tree",
|
||||
Description: TREE_DESCRIPTION,
|
||||
Parameters: []model.ToolParameter{
|
||||
{
|
||||
Name: "relative_path",
|
||||
Type: "string",
|
||||
Description: "If set, display the tree starting from this path relative to the current one.",
|
||||
},
|
||||
{
|
||||
Name: "depth",
|
||||
Type: "integer",
|
||||
Description: "Depth of directory recursion. Defaults to 0 (no recursion), maximum of 5.",
|
||||
},
|
||||
},
|
||||
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
||||
var relativeDir string
|
||||
if tmp, ok := args["relative_path"]; ok {
|
||||
relativeDir, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("expected string for relative_path, got %T", tmp)
|
||||
}
|
||||
}
|
||||
|
||||
var depth int = 0 // Default value if not provided
|
||||
if tmp, ok := args["depth"]; ok {
|
||||
switch v := tmp.(type) {
|
||||
case float64:
|
||||
depth = int(v)
|
||||
case string:
|
||||
var err error
|
||||
if depth, err = strconv.Atoi(v); err != nil {
|
||||
return "", fmt.Errorf("invalid `depth` value, expected integer but got string that cannot convert: %v", tmp)
|
||||
}
|
||||
depth = max(0, min(5, depth))
|
||||
default:
|
||||
return "", fmt.Errorf("expected int or string for max_depth, got %T", tmp)
|
||||
}
|
||||
}
|
||||
|
||||
result := tree(relativeDir, depth)
|
||||
ret, err := result.ToJson()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not serialize result: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
|
||||
func tree(path string, depth int) model.CallResult {
|
||||
if path == "" {
|
||||
path = "."
|
||||
}
|
||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||
if !ok {
|
||||
return model.CallResult{Message: reason}
|
||||
}
|
||||
|
||||
var treeOutput strings.Builder
|
||||
treeOutput.WriteString(path + "\n")
|
||||
err := buildTree(&treeOutput, path, "", depth)
|
||||
if err != nil {
|
||||
return model.CallResult{
|
||||
Message: err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
return model.CallResult{Result: treeOutput.String()}
|
||||
}
|
||||
|
||||
func buildTree(output *strings.Builder, path string, prefix string, depth int) error {
|
||||
files, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, file := range files {
|
||||
if strings.HasPrefix(file.Name(), ".") {
|
||||
// Skip hidden files and directories
|
||||
continue
|
||||
}
|
||||
|
||||
isLast := i == len(files)-1
|
||||
var branch string
|
||||
if isLast {
|
||||
branch = "└── "
|
||||
} else {
|
||||
branch = "├── "
|
||||
}
|
||||
|
||||
info, _ := file.Info()
|
||||
size := info.Size()
|
||||
sizeStr := fmt.Sprintf(" (%d bytes)", size)
|
||||
|
||||
output.WriteString(prefix + branch + file.Name())
|
||||
if file.IsDir() {
|
||||
output.WriteString("/\n")
|
||||
if depth > 0 {
|
||||
var nextPrefix string
|
||||
if isLast {
|
||||
nextPrefix = prefix + " "
|
||||
} else {
|
||||
nextPrefix = prefix + "│ "
|
||||
}
|
||||
buildTree(output, filepath.Join(path, file.Name()), nextPrefix, depth-1)
|
||||
}
|
||||
} else {
|
||||
output.WriteString(sizeStr + "\n")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,114 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
)
|
||||
|
||||
const FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path.
|
||||
|
||||
Make sure your inserts match the flow and indentation of surrounding content.`
|
||||
|
||||
var FileInsertLinesTool = model.Tool{
|
||||
Name: "file_insert_lines",
|
||||
Description: FILE_INSERT_LINES_DESCRIPTION,
|
||||
Parameters: []model.ToolParameter{
|
||||
{
|
||||
Name: "path",
|
||||
Type: "string",
|
||||
Description: "Path of the file to be modified, relative to the current working directory.",
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "position",
|
||||
Type: "integer",
|
||||
Description: `Which line to insert content *before*.`,
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "content",
|
||||
Type: "string",
|
||||
Description: `The content to insert.`,
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path parameter to write_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
var position int
|
||||
tmp, ok = args["position"]
|
||||
if ok {
|
||||
tmp, ok := tmp.(float64)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid position in function arguments: %v", tmp)
|
||||
}
|
||||
position = int(tmp)
|
||||
}
|
||||
var content string
|
||||
tmp, ok = args["content"]
|
||||
if ok {
|
||||
content, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
|
||||
result := fileInsertLines(path, position, content)
|
||||
ret, err := result.ToJson()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
|
||||
func fileInsertLines(path string, position int, content string) model.CallResult {
|
||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||
if !ok {
|
||||
return model.CallResult{Message: reason}
|
||||
}
|
||||
|
||||
// Read the existing file's content
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
||||
}
|
||||
_, err = os.Create(path)
|
||||
if err != nil {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
|
||||
}
|
||||
data = []byte{}
|
||||
}
|
||||
|
||||
if position < 1 {
|
||||
return model.CallResult{Message: "start_line cannot be less than 1"}
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
|
||||
|
||||
before := lines[:position-1]
|
||||
after := lines[position-1:]
|
||||
lines = append(before, append(contentLines, after...)...)
|
||||
|
||||
newContent := strings.Join(lines, "\n")
|
||||
|
||||
// Join the lines and write back to the file
|
||||
err = os.WriteFile(path, []byte(newContent), 0644)
|
||||
if err != nil {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
||||
}
|
||||
|
||||
return model.CallResult{Result: newContent}
|
||||
}
|
||||
@@ -1,133 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
||||
)
|
||||
|
||||
const FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path.
|
||||
|
||||
Useful for re-writing snippets/blocks of code or entire functions.
|
||||
|
||||
Plan your edits carefully and ensure any new content matches the flow and indentation of surrounding text.`
|
||||
|
||||
var FileReplaceLinesTool = model.Tool{
|
||||
Name: "file_replace_lines",
|
||||
Description: FILE_REPLACE_LINES_DESCRIPTION,
|
||||
Parameters: []model.ToolParameter{
|
||||
{
|
||||
Name: "path",
|
||||
Type: "string",
|
||||
Description: "Path of the file to be modified, relative to the current working directory.",
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "start_line",
|
||||
Type: "integer",
|
||||
Description: `Line number which specifies the start of the replacement range (inclusive).`,
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "end_line",
|
||||
Type: "integer",
|
||||
Description: `Line number which specifies the end of the replacement range (inclusive). If unset, range extends to end of file.`,
|
||||
},
|
||||
{
|
||||
Name: "content",
|
||||
Type: "string",
|
||||
Description: `Content to replace specified range. Omit to remove the specified range.`,
|
||||
},
|
||||
},
|
||||
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path parameter to write_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
var start_line int
|
||||
tmp, ok = args["start_line"]
|
||||
if ok {
|
||||
tmp, ok := tmp.(float64)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid start_line in function arguments: %v", tmp)
|
||||
}
|
||||
start_line = int(tmp)
|
||||
}
|
||||
var end_line int
|
||||
tmp, ok = args["end_line"]
|
||||
if ok {
|
||||
tmp, ok := tmp.(float64)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid end_line in function arguments: %v", tmp)
|
||||
}
|
||||
end_line = int(tmp)
|
||||
}
|
||||
var content string
|
||||
tmp, ok = args["content"]
|
||||
if ok {
|
||||
content, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
|
||||
result := fileReplaceLines(path, start_line, end_line, content)
|
||||
ret, err := result.ToJson()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
|
||||
func fileReplaceLines(path string, startLine int, endLine int, content string) model.CallResult {
|
||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||
if !ok {
|
||||
return model.CallResult{Message: reason}
|
||||
}
|
||||
|
||||
// Read the existing file's content
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
||||
}
|
||||
_, err = os.Create(path)
|
||||
if err != nil {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
|
||||
}
|
||||
data = []byte{}
|
||||
}
|
||||
|
||||
if startLine < 1 {
|
||||
return model.CallResult{Message: "start_line cannot be less than 1"}
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
|
||||
|
||||
if endLine == 0 || endLine > len(lines) {
|
||||
endLine = len(lines)
|
||||
}
|
||||
|
||||
before := lines[:startLine-1]
|
||||
after := lines[endLine:]
|
||||
|
||||
lines = append(before, append(contentLines, after...)...)
|
||||
newContent := strings.Join(lines, "\n")
|
||||
|
||||
// Join the lines and write back to the file
|
||||
err = os.WriteFile(path, []byte(newContent), 0644)
|
||||
if err != nil {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
||||
}
|
||||
|
||||
return model.CallResult{Result: newContent}
|
||||
}
|
||||
@@ -1,100 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
||||
)
|
||||
|
||||
const READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory).
|
||||
|
||||
Example result:
|
||||
{
|
||||
"message": "success",
|
||||
"result": [
|
||||
{"name": "a_file.txt", "type": "file", "size": 123},
|
||||
{"name": "a_directory/", "type": "dir", "size": 11},
|
||||
...
|
||||
]
|
||||
}
|
||||
|
||||
For files, size represents the size of the file, in bytes.
|
||||
For directories, size represents the number of entries in that directory.`
|
||||
|
||||
var ReadDirTool = model.Tool{
|
||||
Name: "read_dir",
|
||||
Description: READ_DIR_DESCRIPTION,
|
||||
Parameters: []model.ToolParameter{
|
||||
{
|
||||
Name: "relative_dir",
|
||||
Type: "string",
|
||||
Description: "If set, read the contents of a directory relative to the current one.",
|
||||
},
|
||||
},
|
||||
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
||||
var relativeDir string
|
||||
tmp, ok := args["relative_dir"]
|
||||
if ok {
|
||||
relativeDir, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
result := readDir(relativeDir)
|
||||
ret, err := result.ToJson()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
|
||||
func readDir(path string) model.CallResult {
|
||||
if path == "" {
|
||||
path = "."
|
||||
}
|
||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||
if !ok {
|
||||
return model.CallResult{Message: reason}
|
||||
}
|
||||
|
||||
files, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return model.CallResult{
|
||||
Message: err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
var dirContents []map[string]interface{}
|
||||
for _, f := range files {
|
||||
info, _ := f.Info()
|
||||
|
||||
name := f.Name()
|
||||
if strings.HasPrefix(name, ".") {
|
||||
// skip hidden files
|
||||
continue
|
||||
}
|
||||
|
||||
entryType := "file"
|
||||
size := info.Size()
|
||||
|
||||
if info.IsDir() {
|
||||
name += "/"
|
||||
entryType = "dir"
|
||||
subdirfiles, _ := os.ReadDir(filepath.Join(".", path, info.Name()))
|
||||
size = int64(len(subdirfiles))
|
||||
}
|
||||
|
||||
dirContents = append(dirContents, map[string]interface{}{
|
||||
"name": name,
|
||||
"type": entryType,
|
||||
"size": size,
|
||||
})
|
||||
}
|
||||
|
||||
return model.CallResult{Result: dirContents}
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
||||
)
|
||||
|
||||
const READ_FILE_DESCRIPTION = `Retrieve the contents of a text file relative to the current working directory.
|
||||
|
||||
Use the file contents for your own reference in completing your task, they do not need to be shown to the user.
|
||||
|
||||
Each line of the returned content is prefixed with its line number and a tab (\t).
|
||||
|
||||
Example result:
|
||||
{
|
||||
"message": "success",
|
||||
"result": "1\tthe contents\n2\tof the file\n"
|
||||
}`
|
||||
|
||||
var ReadFileTool = model.Tool{
|
||||
Name: "read_file",
|
||||
Description: READ_FILE_DESCRIPTION,
|
||||
Parameters: []model.ToolParameter{
|
||||
{
|
||||
Name: "path",
|
||||
Type: "string",
|
||||
Description: "Path to a file within the current working directory to read.",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
|
||||
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Path parameter to read_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
result := readFile(path)
|
||||
ret, err := result.ToJson()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
|
||||
func readFile(path string) model.CallResult {
|
||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||
if !ok {
|
||||
return model.CallResult{Message: reason}
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
content := strings.Builder{}
|
||||
for i, line := range lines {
|
||||
content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
|
||||
}
|
||||
|
||||
return model.CallResult{
|
||||
Result: content.String(),
|
||||
}
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
)
|
||||
|
||||
var AvailableTools map[string]model.Tool = map[string]model.Tool{
|
||||
"dir_tree": DirTreeTool,
|
||||
"read_dir": ReadDirTool,
|
||||
"read_file": ReadFileTool,
|
||||
"write_file": WriteFileTool,
|
||||
"file_insert_lines": FileInsertLinesTool,
|
||||
"file_replace_lines": FileReplaceLinesTool,
|
||||
}
|
||||
|
||||
func ExecuteToolCalls(toolCalls []model.ToolCall, toolBag []model.Tool) ([]model.ToolResult, error) {
|
||||
var toolResults []model.ToolResult
|
||||
for _, toolCall := range toolCalls {
|
||||
var tool *model.Tool
|
||||
for _, available := range toolBag {
|
||||
if available.Name == toolCall.Name {
|
||||
tool = &available
|
||||
break
|
||||
}
|
||||
}
|
||||
if tool == nil {
|
||||
return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Name)
|
||||
}
|
||||
|
||||
// Execute the tool
|
||||
result, err := tool.Impl(tool, toolCall.Parameters)
|
||||
if err != nil {
|
||||
// This can happen if the model missed or supplied invalid tool args
|
||||
return nil, fmt.Errorf("Tool '%s' error: %v\n", toolCall.Name, err)
|
||||
}
|
||||
|
||||
toolResult := model.ToolResult{
|
||||
ToolCallID: toolCall.ID,
|
||||
ToolName: toolCall.Name,
|
||||
Result: result,
|
||||
}
|
||||
|
||||
toolResults = append(toolResults, toolResult)
|
||||
}
|
||||
return toolResults, nil
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// isPathContained attempts to verify whether `path` is the same as or
|
||||
// contained within `directory`. It is overly cautious, returning false even if
|
||||
// `path` IS contained within `directory`, but the two paths use different
|
||||
// casing, and we happen to be on a case-insensitive filesystem.
|
||||
// This is ultimately to attempt to stop an LLM from going outside of where I
|
||||
// tell it to. Additional layers of security should be considered.. run in a
|
||||
// VM/container.
|
||||
func IsPathContained(directory string, path string) (bool, error) {
|
||||
// Clean and resolve symlinks for both paths
|
||||
path, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// check if path exists
|
||||
_, err = os.Stat(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return false, fmt.Errorf("Could not stat path: %v", err)
|
||||
}
|
||||
} else {
|
||||
path, err = filepath.EvalSymlinks(path)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
directory, err = filepath.Abs(directory)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
directory, err = filepath.EvalSymlinks(directory)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Case insensitive checks
|
||||
if !strings.EqualFold(path, directory) &&
|
||||
!strings.HasPrefix(strings.ToLower(path), strings.ToLower(directory)+string(os.PathSeparator)) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func IsPathWithinCWD(path string) (bool, string) {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return false, "Failed to determine current working directory"
|
||||
}
|
||||
if ok, err := IsPathContained(cwd, path); !ok {
|
||||
if err != nil {
|
||||
return false, fmt.Sprintf("Could not determine whether path '%s' is within the current working directory: %s", path, err.Error())
|
||||
}
|
||||
return false, fmt.Sprintf("Path '%s' is not within the current working directory", path)
|
||||
}
|
||||
return true, ""
|
||||
}
|
||||
@@ -1,71 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
||||
)
|
||||
|
||||
const WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory.
|
||||
|
||||
Example result:
|
||||
{
|
||||
"message": "success"
|
||||
}`
|
||||
|
||||
var WriteFileTool = model.Tool{
|
||||
Name: "write_file",
|
||||
Description: WRITE_FILE_DESCRIPTION,
|
||||
Parameters: []model.ToolParameter{
|
||||
{
|
||||
Name: "path",
|
||||
Type: "string",
|
||||
Description: "Path to a file within the current working directory to write to.",
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "content",
|
||||
Type: "string",
|
||||
Description: "The content to write to the file. Overwrites any existing content!",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
Impl: func(t *model.Tool, args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Path parameter to write_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
tmp, ok = args["content"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Content parameter to write_file was not included.")
|
||||
}
|
||||
content, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||
}
|
||||
result := writeFile(path, content)
|
||||
ret, err := result.ToJson()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
|
||||
func writeFile(path string, content string) model.CallResult {
|
||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||
if !ok {
|
||||
return model.CallResult{Message: reason}
|
||||
}
|
||||
err := os.WriteFile(path, []byte(content), 0644)
|
||||
if err != nil {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
||||
}
|
||||
return model.CallResult{}
|
||||
}
|
||||
Reference in New Issue
Block a user