Private
Public Access
1
0

Package restructure and API changes, several fixes

- More emphasis on `api` package. It now holds database model structs
  from `lmcli/models` (which is now gone) as well as the tool spec,
  call, and result types. `tools.Tool` is now `api.ToolSpec`.
  `api.ChatCompletionClient` was renamed to
  `api.ChatCompletionProvider`.

- Change ChatCompletion interface and implementations to no longer do
  automatic tool call recursion - they simply return a ToolCall message
  which the caller can decide what to do with (e.g. prompt for user
  confirmation before executing)

- `api.ChatCompletionProvider` functions have had their ReplyCallback
  parameter removed, as now they only return a single reply.

- Added a top-level `agent` package, moved the current built-in tools
  implementations under `agent/toolbox`. `tools.ExecuteToolCalls` is now
  `agent.ExecuteToolCalls`.

- Fixed request context handling in openai, google, ollama (use
  `NewRequestWithContext`), cleaned up request cancellation in TUI

- Fix tool call tui persistence bug (we were skipping message with empty
  content)

- Now handle tool calling from TUI layer

TODO:
- Prompt users before executing tool calls
- Automatically send tool results to the model (or make this toggleable)
This commit is contained in:
2024-06-12 08:35:07 +00:00
parent 85a2abbbf3
commit 3fde58b77d
35 changed files with 608 additions and 749 deletions

View File

@@ -6,13 +6,12 @@ import (
"path/filepath"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/agent"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/anthropic"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/google"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/openai"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
"git.mlow.ca/mlow/lmcli/pkg/util"
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
"gorm.io/driver/sqlite"
@@ -24,7 +23,7 @@ type Context struct {
Store ConversationStore
Chroma *tty.ChromaHighlighter
EnabledTools []model.Tool
EnabledTools []api.ToolSpec
SystemPromptFile string
}
@@ -50,9 +49,9 @@ func NewContext() (*Context, error) {
chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)
var enabledTools []model.Tool
var enabledTools []api.ToolSpec
for _, toolName := range config.Tools.EnabledTools {
tool, ok := tools.AvailableTools[toolName]
tool, ok := agent.AvailableTools[toolName]
if ok {
enabledTools = append(enabledTools, tool)
}
@@ -79,7 +78,7 @@ func (c *Context) GetModels() (models []string) {
return
}
func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionClient, error) {
func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProvider, error) {
parts := strings.Split(model, "@")
var provider string

View File

@@ -1,77 +0,0 @@
package model
import (
"database/sql"
"time"
)
type MessageRole string
const (
MessageRoleSystem MessageRole = "system"
MessageRoleUser MessageRole = "user"
MessageRoleAssistant MessageRole = "assistant"
MessageRoleToolCall MessageRole = "tool_call"
MessageRoleToolResult MessageRole = "tool_result"
)
type Message struct {
ID uint `gorm:"primaryKey"`
ConversationID *uint `gorm:"index"`
Conversation *Conversation `gorm:"foreignKey:ConversationID"`
Content string
Role MessageRole
CreatedAt time.Time
ToolCalls ToolCalls // a json array of tool calls (from the model)
ToolResults ToolResults // a json array of tool results
ParentID *uint
Parent *Message `gorm:"foreignKey:ParentID"`
Replies []Message `gorm:"foreignKey:ParentID"`
SelectedReplyID *uint
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
}
type Conversation struct {
ID uint `gorm:"primaryKey"`
ShortName sql.NullString
Title string
SelectedRootID *uint
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
}
type RequestParameters struct {
Model string
MaxTokens int
Temperature float32
TopP float32
ToolBag []Tool
}
func (m *MessageRole) IsAssistant() bool {
switch *m {
case MessageRoleAssistant, MessageRoleToolCall:
return true
}
return false
}
// FriendlyRole returns a human friendly signifier for the message's role.
func (m MessageRole) FriendlyRole() string {
switch m {
case MessageRoleUser:
return "You"
case MessageRoleSystem:
return "System"
case MessageRoleAssistant:
return "Assistant"
case MessageRoleToolCall:
return "Tool Call"
case MessageRoleToolResult:
return "Tool Result"
default:
return string(m)
}
}

View File

@@ -1,98 +0,0 @@
package model
import (
"database/sql/driver"
"encoding/json"
"fmt"
)
type Tool struct {
Name string
Description string
Parameters []ToolParameter
Impl func(*Tool, map[string]interface{}) (string, error)
}
type ToolParameter struct {
Name string `json:"name"`
Type string `json:"type"` // "string", "integer", "boolean"
Required bool `json:"required"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}
type ToolCall struct {
ID string `json:"id" yaml:"-"`
Name string `json:"name" yaml:"tool"`
Parameters map[string]interface{} `json:"parameters" yaml:"parameters"`
}
type ToolCalls []ToolCall
func (tc *ToolCalls) Scan(value any) (err error) {
s := value.(string)
if value == nil || s == "" {
*tc = nil
return
}
err = json.Unmarshal([]byte(s), tc)
return
}
func (tc ToolCalls) Value() (driver.Value, error) {
if len(tc) == 0 {
return "", nil
}
jsonBytes, err := json.Marshal(tc)
if err != nil {
return "", fmt.Errorf("Could not marshal ToolCalls to JSON: %v\n", err)
}
return string(jsonBytes), nil
}
type ToolResult struct {
ToolCallID string `json:"toolCallID" yaml:"-"`
ToolName string `json:"toolName,omitempty" yaml:"tool"`
Result string `json:"result,omitempty" yaml:"result"`
}
type ToolResults []ToolResult
func (tr *ToolResults) Scan(value any) (err error) {
s := value.(string)
if value == nil || s == "" {
*tr = nil
return
}
err = json.Unmarshal([]byte(s), tr)
return
}
func (tr ToolResults) Value() (driver.Value, error) {
if len(tr) == 0 {
return "", nil
}
jsonBytes, err := json.Marshal([]ToolResult(tr))
if err != nil {
return "", fmt.Errorf("Could not marshal ToolResults to JSON: %v\n", err)
}
return string(jsonBytes), nil
}
type CallResult struct {
Message string `json:"message"`
Result any `json:"result,omitempty"`
}
func (r CallResult) ToJson() (string, error) {
if r.Message == "" {
// When message not supplied, assume success
r.Message = "success"
}
jsonBytes, err := json.Marshal(r)
if err != nil {
return "", fmt.Errorf("Could not marshal CallResult to JSON: %v\n", err)
}
return string(jsonBytes), nil
}

View File

@@ -8,32 +8,32 @@ import (
"strings"
"time"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/api"
sqids "github.com/sqids/sqids-go"
"gorm.io/gorm"
)
type ConversationStore interface {
ConversationByShortName(shortName string) (*model.Conversation, error)
ConversationByShortName(shortName string) (*api.Conversation, error)
ConversationShortNameCompletions(search string) []string
RootMessages(conversationID uint) ([]model.Message, error)
LatestConversationMessages() ([]model.Message, error)
RootMessages(conversationID uint) ([]api.Message, error)
LatestConversationMessages() ([]api.Message, error)
StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error)
UpdateConversation(conversation *model.Conversation) error
DeleteConversation(conversation *model.Conversation) error
CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error)
StartConversation(messages ...api.Message) (*api.Conversation, []api.Message, error)
UpdateConversation(conversation *api.Conversation) error
DeleteConversation(conversation *api.Conversation) error
CloneConversation(toClone api.Conversation) (*api.Conversation, uint, error)
MessageByID(messageID uint) (*model.Message, error)
MessageReplies(messageID uint) ([]model.Message, error)
MessageByID(messageID uint) (*api.Message, error)
MessageReplies(messageID uint) ([]api.Message, error)
UpdateMessage(message *model.Message) error
DeleteMessage(message *model.Message, prune bool) error
CloneBranch(toClone model.Message) (*model.Message, uint, error)
Reply(to *model.Message, messages ...model.Message) ([]model.Message, error)
UpdateMessage(message *api.Message) error
DeleteMessage(message *api.Message, prune bool) error
CloneBranch(toClone api.Message) (*api.Message, uint, error)
Reply(to *api.Message, messages ...api.Message) ([]api.Message, error)
PathToRoot(message *model.Message) ([]model.Message, error)
PathToLeaf(message *model.Message) ([]model.Message, error)
PathToRoot(message *api.Message) ([]api.Message, error)
PathToLeaf(message *api.Message) ([]api.Message, error)
}
type SQLStore struct {
@@ -43,8 +43,8 @@ type SQLStore struct {
func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
models := []any{
&model.Conversation{},
&model.Message{},
&api.Conversation{},
&api.Message{},
}
for _, x := range models {
@@ -58,9 +58,9 @@ func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
return &SQLStore{db, _sqids}, nil
}
func (s *SQLStore) createConversation() (*model.Conversation, error) {
func (s *SQLStore) createConversation() (*api.Conversation, error) {
// Create the new conversation
c := &model.Conversation{}
c := &api.Conversation{}
err := s.db.Save(c).Error
if err != nil {
return nil, err
@@ -75,28 +75,28 @@ func (s *SQLStore) createConversation() (*model.Conversation, error) {
return c, nil
}
func (s *SQLStore) UpdateConversation(c *model.Conversation) error {
func (s *SQLStore) UpdateConversation(c *api.Conversation) error {
if c == nil || c.ID == 0 {
return fmt.Errorf("Conversation is nil or invalid (missing ID)")
}
return s.db.Updates(c).Error
}
func (s *SQLStore) DeleteConversation(c *model.Conversation) error {
func (s *SQLStore) DeleteConversation(c *api.Conversation) error {
// Delete messages first
err := s.db.Where("conversation_id = ?", c.ID).Delete(&model.Message{}).Error
err := s.db.Where("conversation_id = ?", c.ID).Delete(&api.Message{}).Error
if err != nil {
return err
}
return s.db.Delete(c).Error
}
func (s *SQLStore) DeleteMessage(message *model.Message, prune bool) error {
func (s *SQLStore) DeleteMessage(message *api.Message, prune bool) error {
panic("Not yet implemented")
//return s.db.Delete(&message).Error
}
func (s *SQLStore) UpdateMessage(m *model.Message) error {
func (s *SQLStore) UpdateMessage(m *api.Message) error {
if m == nil || m.ID == 0 {
return fmt.Errorf("Message is nil or invalid (missing ID)")
}
@@ -104,7 +104,7 @@ func (s *SQLStore) UpdateMessage(m *model.Message) error {
}
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
var conversations []model.Conversation
var conversations []api.Conversation
// ignore error for completions
s.db.Find(&conversations)
completions := make([]string, 0, len(conversations))
@@ -116,17 +116,17 @@ func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
return completions
}
func (s *SQLStore) ConversationByShortName(shortName string) (*model.Conversation, error) {
func (s *SQLStore) ConversationByShortName(shortName string) (*api.Conversation, error) {
if shortName == "" {
return nil, errors.New("shortName is empty")
}
var conversation model.Conversation
var conversation api.Conversation
err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error
return &conversation, err
}
func (s *SQLStore) RootMessages(conversationID uint) ([]model.Message, error) {
var rootMessages []model.Message
func (s *SQLStore) RootMessages(conversationID uint) ([]api.Message, error) {
var rootMessages []api.Message
err := s.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error
if err != nil {
return nil, err
@@ -134,20 +134,20 @@ func (s *SQLStore) RootMessages(conversationID uint) ([]model.Message, error) {
return rootMessages, nil
}
func (s *SQLStore) MessageByID(messageID uint) (*model.Message, error) {
var message model.Message
func (s *SQLStore) MessageByID(messageID uint) (*api.Message, error) {
var message api.Message
err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error
return &message, err
}
func (s *SQLStore) MessageReplies(messageID uint) ([]model.Message, error) {
var replies []model.Message
func (s *SQLStore) MessageReplies(messageID uint) ([]api.Message, error) {
var replies []api.Message
err := s.db.Where("parent_id = ?", messageID).Find(&replies).Error
return replies, err
}
// StartConversation starts a new conversation with the provided messages
func (s *SQLStore) StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error) {
func (s *SQLStore) StartConversation(messages ...api.Message) (*api.Conversation, []api.Message, error) {
if len(messages) == 0 {
return nil, nil, fmt.Errorf("Must provide at least 1 message")
}
@@ -178,13 +178,13 @@ func (s *SQLStore) StartConversation(messages ...model.Message) (*model.Conversa
if err != nil {
return nil, nil, err
}
messages = append([]model.Message{messages[0]}, newMessages...)
messages = append([]api.Message{messages[0]}, newMessages...)
}
return conversation, messages, nil
}
// CloneConversation clones the given conversation and all of its root meesages
func (s *SQLStore) CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error) {
func (s *SQLStore) CloneConversation(toClone api.Conversation) (*api.Conversation, uint, error) {
rootMessages, err := s.RootMessages(toClone.ID)
if err != nil {
return nil, 0, err
@@ -226,8 +226,8 @@ func (s *SQLStore) CloneConversation(toClone model.Conversation) (*model.Convers
}
// Reply to a message with a series of messages (each following the next)
func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model.Message, error) {
var savedMessages []model.Message
func (s *SQLStore) Reply(to *api.Message, messages ...api.Message) ([]api.Message, error) {
var savedMessages []api.Message
err := s.db.Transaction(func(tx *gorm.DB) error {
currentParent := to
@@ -262,7 +262,7 @@ func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model.
// CloneBranch returns a deep clone of the given message and its replies, returning
// a new message object. The new message will be attached to the same parent as
// the messageToClone
func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, uint, error) {
func (s *SQLStore) CloneBranch(messageToClone api.Message) (*api.Message, uint, error) {
newMessage := messageToClone
newMessage.ID = 0
newMessage.Replies = nil
@@ -304,19 +304,19 @@ func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, ui
return &newMessage, replyCount, nil
}
func fetchMessages(db *gorm.DB) ([]model.Message, error) {
var messages []model.Message
func fetchMessages(db *gorm.DB) ([]api.Message, error) {
var messages []api.Message
if err := db.Preload("Conversation").Find(&messages).Error; err != nil {
return nil, fmt.Errorf("Could not fetch messages: %v", err)
}
messageMap := make(map[uint]model.Message)
messageMap := make(map[uint]api.Message)
for i, message := range messages {
messageMap[messages[i].ID] = message
}
// Create a map to store replies by their parent ID
repliesMap := make(map[uint][]model.Message)
repliesMap := make(map[uint][]api.Message)
for i, message := range messages {
if messages[i].ParentID != nil {
repliesMap[*messages[i].ParentID] = append(repliesMap[*messages[i].ParentID], message)
@@ -326,7 +326,7 @@ func fetchMessages(db *gorm.DB) ([]model.Message, error) {
// Assign replies, parent, and selected reply to each message
for i := range messages {
if replies, exists := repliesMap[messages[i].ID]; exists {
messages[i].Replies = make([]model.Message, len(replies))
messages[i].Replies = make([]api.Message, len(replies))
for j, m := range replies {
messages[i].Replies[j] = m
}
@@ -345,21 +345,21 @@ func fetchMessages(db *gorm.DB) ([]model.Message, error) {
return messages, nil
}
func (s *SQLStore) buildPath(message *model.Message, getNext func(*model.Message) *uint) ([]model.Message, error) {
var messages []model.Message
func (s *SQLStore) buildPath(message *api.Message, getNext func(*api.Message) *uint) ([]api.Message, error) {
var messages []api.Message
messages, err := fetchMessages(s.db.Where("conversation_id = ?", message.ConversationID))
if err != nil {
return nil, err
}
// Create a map to store messages by their ID
messageMap := make(map[uint]*model.Message)
messageMap := make(map[uint]*api.Message)
for i := range messages {
messageMap[messages[i].ID] = &messages[i]
}
// Build the path
var path []model.Message
var path []api.Message
nextID := &message.ID
for {
@@ -382,12 +382,12 @@ func (s *SQLStore) buildPath(message *model.Message, getNext func(*model.Message
// PathToRoot traverses the provided message's Parent until reaching the tree
// root and returns a slice of all messages traversed in chronological order
// (starting with the root and ending with the message provided)
func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {
func (s *SQLStore) PathToRoot(message *api.Message) ([]api.Message, error) {
if message == nil || message.ID <= 0 {
return nil, fmt.Errorf("Message is nil or has invalid ID")
}
path, err := s.buildPath(message, func(m *model.Message) *uint {
path, err := s.buildPath(message, func(m *api.Message) *uint {
return m.ParentID
})
if err != nil {
@@ -401,24 +401,24 @@ func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {
// PathToLeaf traverses the provided message's SelectedReply until reaching a
// tree leaf and returns a slice of all messages traversed in chronological
// order (starting with the message provided and ending with the leaf)
func (s *SQLStore) PathToLeaf(message *model.Message) ([]model.Message, error) {
func (s *SQLStore) PathToLeaf(message *api.Message) ([]api.Message, error) {
if message == nil || message.ID <= 0 {
return nil, fmt.Errorf("Message is nil or has invalid ID")
}
return s.buildPath(message, func(m *model.Message) *uint {
return s.buildPath(message, func(m *api.Message) *uint {
return m.SelectedReplyID
})
}
func (s *SQLStore) LatestConversationMessages() ([]model.Message, error) {
var latestMessages []model.Message
func (s *SQLStore) LatestConversationMessages() ([]api.Message, error) {
var latestMessages []api.Message
subQuery := s.db.Model(&model.Message{}).
subQuery := s.db.Model(&api.Message{}).
Select("MAX(created_at) as max_created_at, conversation_id").
Group("conversation_id")
err := s.db.Model(&model.Message{}).
err := s.db.Model(&api.Message{}).
Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery).
Group("messages.conversation_id").
Order("created_at DESC").

View File

@@ -1,142 +0,0 @@
package tools
import (
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
)
const TREE_DESCRIPTION = `Retrieve a tree-like view of a directory's contents.
Use these results for your own reference in completing your task, they do not need to be shown to the user.
Example result:
{
"message": "success",
"result": ".
├── a_directory/
│ ├── file1.txt (100 bytes)
│ └── file2.txt (200 bytes)
├── a_file.txt (123 bytes)
└── another_file.txt (456 bytes)"
}
`
var DirTreeTool = model.Tool{
Name: "dir_tree",
Description: TREE_DESCRIPTION,
Parameters: []model.ToolParameter{
{
Name: "relative_path",
Type: "string",
Description: "If set, display the tree starting from this path relative to the current one.",
},
{
Name: "depth",
Type: "integer",
Description: "Depth of directory recursion. Defaults to 0 (no recursion), maximum of 5.",
},
},
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
var relativeDir string
if tmp, ok := args["relative_path"]; ok {
relativeDir, ok = tmp.(string)
if !ok {
return "", fmt.Errorf("expected string for relative_path, got %T", tmp)
}
}
var depth int = 0 // Default value if not provided
if tmp, ok := args["depth"]; ok {
switch v := tmp.(type) {
case float64:
depth = int(v)
case string:
var err error
if depth, err = strconv.Atoi(v); err != nil {
return "", fmt.Errorf("invalid `depth` value, expected integer but got string that cannot convert: %v", tmp)
}
depth = max(0, min(5, depth))
default:
return "", fmt.Errorf("expected int or string for max_depth, got %T", tmp)
}
}
result := tree(relativeDir, depth)
ret, err := result.ToJson()
if err != nil {
return "", fmt.Errorf("could not serialize result: %v", err)
}
return ret, nil
},
}
func tree(path string, depth int) model.CallResult {
if path == "" {
path = "."
}
ok, reason := toolutil.IsPathWithinCWD(path)
if !ok {
return model.CallResult{Message: reason}
}
var treeOutput strings.Builder
treeOutput.WriteString(path + "\n")
err := buildTree(&treeOutput, path, "", depth)
if err != nil {
return model.CallResult{
Message: err.Error(),
}
}
return model.CallResult{Result: treeOutput.String()}
}
func buildTree(output *strings.Builder, path string, prefix string, depth int) error {
files, err := os.ReadDir(path)
if err != nil {
return err
}
for i, file := range files {
if strings.HasPrefix(file.Name(), ".") {
// Skip hidden files and directories
continue
}
isLast := i == len(files)-1
var branch string
if isLast {
branch = "└── "
} else {
branch = "├── "
}
info, _ := file.Info()
size := info.Size()
sizeStr := fmt.Sprintf(" (%d bytes)", size)
output.WriteString(prefix + branch + file.Name())
if file.IsDir() {
output.WriteString("/\n")
if depth > 0 {
var nextPrefix string
if isLast {
nextPrefix = prefix + " "
} else {
nextPrefix = prefix + "│ "
}
buildTree(output, filepath.Join(path, file.Name()), nextPrefix, depth-1)
}
} else {
output.WriteString(sizeStr + "\n")
}
}
return nil
}

View File

@@ -1,114 +0,0 @@
package tools
import (
"fmt"
"os"
"strings"
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
)
const FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path.
Make sure your inserts match the flow and indentation of surrounding content.`
var FileInsertLinesTool = model.Tool{
Name: "file_insert_lines",
Description: FILE_INSERT_LINES_DESCRIPTION,
Parameters: []model.ToolParameter{
{
Name: "path",
Type: "string",
Description: "Path of the file to be modified, relative to the current working directory.",
Required: true,
},
{
Name: "position",
Type: "integer",
Description: `Which line to insert content *before*.`,
Required: true,
},
{
Name: "content",
Type: "string",
Description: `The content to insert.`,
Required: true,
},
},
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
tmp, ok := args["path"]
if !ok {
return "", fmt.Errorf("path parameter to write_file was not included.")
}
path, ok := tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
}
var position int
tmp, ok = args["position"]
if ok {
tmp, ok := tmp.(float64)
if !ok {
return "", fmt.Errorf("Invalid position in function arguments: %v", tmp)
}
position = int(tmp)
}
var content string
tmp, ok = args["content"]
if ok {
content, ok = tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
}
}
result := fileInsertLines(path, position, content)
ret, err := result.ToJson()
if err != nil {
return "", fmt.Errorf("Could not serialize result: %v", err)
}
return ret, nil
},
}
func fileInsertLines(path string, position int, content string) model.CallResult {
ok, reason := toolutil.IsPathWithinCWD(path)
if !ok {
return model.CallResult{Message: reason}
}
// Read the existing file's content
data, err := os.ReadFile(path)
if err != nil {
if !os.IsNotExist(err) {
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
}
_, err = os.Create(path)
if err != nil {
return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
}
data = []byte{}
}
if position < 1 {
return model.CallResult{Message: "start_line cannot be less than 1"}
}
lines := strings.Split(string(data), "\n")
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
before := lines[:position-1]
after := lines[position-1:]
lines = append(before, append(contentLines, after...)...)
newContent := strings.Join(lines, "\n")
// Join the lines and write back to the file
err = os.WriteFile(path, []byte(newContent), 0644)
if err != nil {
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
}
return model.CallResult{Result: newContent}
}

View File

@@ -1,133 +0,0 @@
package tools
import (
"fmt"
"os"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
)
const FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path.
Useful for re-writing snippets/blocks of code or entire functions.
Plan your edits carefully and ensure any new content matches the flow and indentation of surrounding text.`
var FileReplaceLinesTool = model.Tool{
Name: "file_replace_lines",
Description: FILE_REPLACE_LINES_DESCRIPTION,
Parameters: []model.ToolParameter{
{
Name: "path",
Type: "string",
Description: "Path of the file to be modified, relative to the current working directory.",
Required: true,
},
{
Name: "start_line",
Type: "integer",
Description: `Line number which specifies the start of the replacement range (inclusive).`,
Required: true,
},
{
Name: "end_line",
Type: "integer",
Description: `Line number which specifies the end of the replacement range (inclusive). If unset, range extends to end of file.`,
},
{
Name: "content",
Type: "string",
Description: `Content to replace specified range. Omit to remove the specified range.`,
},
},
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
tmp, ok := args["path"]
if !ok {
return "", fmt.Errorf("path parameter to write_file was not included.")
}
path, ok := tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
}
var start_line int
tmp, ok = args["start_line"]
if ok {
tmp, ok := tmp.(float64)
if !ok {
return "", fmt.Errorf("Invalid start_line in function arguments: %v", tmp)
}
start_line = int(tmp)
}
var end_line int
tmp, ok = args["end_line"]
if ok {
tmp, ok := tmp.(float64)
if !ok {
return "", fmt.Errorf("Invalid end_line in function arguments: %v", tmp)
}
end_line = int(tmp)
}
var content string
tmp, ok = args["content"]
if ok {
content, ok = tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
}
}
result := fileReplaceLines(path, start_line, end_line, content)
ret, err := result.ToJson()
if err != nil {
return "", fmt.Errorf("Could not serialize result: %v", err)
}
return ret, nil
},
}
func fileReplaceLines(path string, startLine int, endLine int, content string) model.CallResult {
ok, reason := toolutil.IsPathWithinCWD(path)
if !ok {
return model.CallResult{Message: reason}
}
// Read the existing file's content
data, err := os.ReadFile(path)
if err != nil {
if !os.IsNotExist(err) {
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
}
_, err = os.Create(path)
if err != nil {
return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
}
data = []byte{}
}
if startLine < 1 {
return model.CallResult{Message: "start_line cannot be less than 1"}
}
lines := strings.Split(string(data), "\n")
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
if endLine == 0 || endLine > len(lines) {
endLine = len(lines)
}
before := lines[:startLine-1]
after := lines[endLine:]
lines = append(before, append(contentLines, after...)...)
newContent := strings.Join(lines, "\n")
// Join the lines and write back to the file
err = os.WriteFile(path, []byte(newContent), 0644)
if err != nil {
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
}
return model.CallResult{Result: newContent}
}

View File

@@ -1,100 +0,0 @@
package tools
import (
"fmt"
"os"
"path/filepath"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
)
const READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory).
Example result:
{
"message": "success",
"result": [
{"name": "a_file.txt", "type": "file", "size": 123},
{"name": "a_directory/", "type": "dir", "size": 11},
...
]
}
For files, size represents the size of the file, in bytes.
For directories, size represents the number of entries in that directory.`
var ReadDirTool = model.Tool{
Name: "read_dir",
Description: READ_DIR_DESCRIPTION,
Parameters: []model.ToolParameter{
{
Name: "relative_dir",
Type: "string",
Description: "If set, read the contents of a directory relative to the current one.",
},
},
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
var relativeDir string
tmp, ok := args["relative_dir"]
if ok {
relativeDir, ok = tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp)
}
}
result := readDir(relativeDir)
ret, err := result.ToJson()
if err != nil {
return "", fmt.Errorf("Could not serialize result: %v", err)
}
return ret, nil
},
}
func readDir(path string) model.CallResult {
if path == "" {
path = "."
}
ok, reason := toolutil.IsPathWithinCWD(path)
if !ok {
return model.CallResult{Message: reason}
}
files, err := os.ReadDir(path)
if err != nil {
return model.CallResult{
Message: err.Error(),
}
}
var dirContents []map[string]interface{}
for _, f := range files {
info, _ := f.Info()
name := f.Name()
if strings.HasPrefix(name, ".") {
// skip hidden files
continue
}
entryType := "file"
size := info.Size()
if info.IsDir() {
name += "/"
entryType = "dir"
subdirfiles, _ := os.ReadDir(filepath.Join(".", path, info.Name()))
size = int64(len(subdirfiles))
}
dirContents = append(dirContents, map[string]interface{}{
"name": name,
"type": entryType,
"size": size,
})
}
return model.CallResult{Result: dirContents}
}

View File

@@ -1,73 +0,0 @@
package tools
import (
"fmt"
"os"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
)
const READ_FILE_DESCRIPTION = `Retrieve the contents of a text file relative to the current working directory.
Use the file contents for your own reference in completing your task, they do not need to be shown to the user.
Each line of the returned content is prefixed with its line number and a tab (\t).
Example result:
{
"message": "success",
"result": "1\tthe contents\n2\tof the file\n"
}`
var ReadFileTool = model.Tool{
Name: "read_file",
Description: READ_FILE_DESCRIPTION,
Parameters: []model.ToolParameter{
{
Name: "path",
Type: "string",
Description: "Path to a file within the current working directory to read.",
Required: true,
},
},
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
tmp, ok := args["path"]
if !ok {
return "", fmt.Errorf("Path parameter to read_file was not included.")
}
path, ok := tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
}
result := readFile(path)
ret, err := result.ToJson()
if err != nil {
return "", fmt.Errorf("Could not serialize result: %v", err)
}
return ret, nil
},
}
func readFile(path string) model.CallResult {
ok, reason := toolutil.IsPathWithinCWD(path)
if !ok {
return model.CallResult{Message: reason}
}
data, err := os.ReadFile(path)
if err != nil {
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
}
lines := strings.Split(string(data), "\n")
content := strings.Builder{}
for i, line := range lines {
content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
}
return model.CallResult{
Result: content.String(),
}
}

View File

@@ -1,48 +0,0 @@
package tools
import (
"fmt"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
)
var AvailableTools map[string]model.Tool = map[string]model.Tool{
"dir_tree": DirTreeTool,
"read_dir": ReadDirTool,
"read_file": ReadFileTool,
"write_file": WriteFileTool,
"file_insert_lines": FileInsertLinesTool,
"file_replace_lines": FileReplaceLinesTool,
}
func ExecuteToolCalls(toolCalls []model.ToolCall, toolBag []model.Tool) ([]model.ToolResult, error) {
var toolResults []model.ToolResult
for _, toolCall := range toolCalls {
var tool *model.Tool
for _, available := range toolBag {
if available.Name == toolCall.Name {
tool = &available
break
}
}
if tool == nil {
return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Name)
}
// Execute the tool
result, err := tool.Impl(tool, toolCall.Parameters)
if err != nil {
// This can happen if the model missed or supplied invalid tool args
return nil, fmt.Errorf("Tool '%s' error: %v\n", toolCall.Name, err)
}
toolResult := model.ToolResult{
ToolCallID: toolCall.ID,
ToolName: toolCall.Name,
Result: result,
}
toolResults = append(toolResults, toolResult)
}
return toolResults, nil
}

View File

@@ -1,67 +0,0 @@
package util
import (
"fmt"
"os"
"path/filepath"
"strings"
)
// isPathContained attempts to verify whether `path` is the same as or
// contained within `directory`. It is overly cautious, returning false even if
// `path` IS contained within `directory`, but the two paths use different
// casing, and we happen to be on a case-insensitive filesystem.
// This is ultimately to attempt to stop an LLM from going outside of where I
// tell it to. Additional layers of security should be considered.. run in a
// VM/container.
func IsPathContained(directory string, path string) (bool, error) {
// Clean and resolve symlinks for both paths
path, err := filepath.Abs(path)
if err != nil {
return false, err
}
// check if path exists
_, err = os.Stat(path)
if err != nil {
if !os.IsNotExist(err) {
return false, fmt.Errorf("Could not stat path: %v", err)
}
} else {
path, err = filepath.EvalSymlinks(path)
if err != nil {
return false, err
}
}
directory, err = filepath.Abs(directory)
if err != nil {
return false, err
}
directory, err = filepath.EvalSymlinks(directory)
if err != nil {
return false, err
}
// Case insensitive checks
if !strings.EqualFold(path, directory) &&
!strings.HasPrefix(strings.ToLower(path), strings.ToLower(directory)+string(os.PathSeparator)) {
return false, nil
}
return true, nil
}
func IsPathWithinCWD(path string) (bool, string) {
cwd, err := os.Getwd()
if err != nil {
return false, "Failed to determine current working directory"
}
if ok, err := IsPathContained(cwd, path); !ok {
if err != nil {
return false, fmt.Sprintf("Could not determine whether path '%s' is within the current working directory: %s", path, err.Error())
}
return false, fmt.Sprintf("Path '%s' is not within the current working directory", path)
}
return true, ""
}

View File

@@ -1,71 +0,0 @@
package tools
import (
"fmt"
"os"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
)
const WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory.
Example result:
{
"message": "success"
}`
var WriteFileTool = model.Tool{
Name: "write_file",
Description: WRITE_FILE_DESCRIPTION,
Parameters: []model.ToolParameter{
{
Name: "path",
Type: "string",
Description: "Path to a file within the current working directory to write to.",
Required: true,
},
{
Name: "content",
Type: "string",
Description: "The content to write to the file. Overwrites any existing content!",
Required: true,
},
},
Impl: func(t *model.Tool, args map[string]interface{}) (string, error) {
tmp, ok := args["path"]
if !ok {
return "", fmt.Errorf("Path parameter to write_file was not included.")
}
path, ok := tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
}
tmp, ok = args["content"]
if !ok {
return "", fmt.Errorf("Content parameter to write_file was not included.")
}
content, ok := tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
}
result := writeFile(path, content)
ret, err := result.ToJson()
if err != nil {
return "", fmt.Errorf("Could not serialize result: %v", err)
}
return ret, nil
},
}
func writeFile(path string, content string) model.CallResult {
ok, reason := toolutil.IsPathWithinCWD(path)
if !ok {
return model.CallResult{Message: reason}
}
err := os.WriteFile(path, []byte(content), 0644)
if err != nil {
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
}
return model.CallResult{}
}