lmcli/pkg/cli/functions.go
2023-11-26 10:51:06 +00:00

519 lines
15 KiB
Go

package cli
import (
"database/sql"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
openai "github.com/sashabaranov/go-openai"
)
// FunctionResult is the envelope that tool implementations serialize to JSON
// and return as the textual result of a tool call.
type FunctionResult struct {
	// Message is emitted under the "message" key.
	Message string `json:"message"`

	// Result is the optional payload, emitted under "result" and left out of
	// the JSON when it is nil.
	Result any `json:"result,omitempty"`
}
// FunctionParameter describes a single named argument of a tool, in the
// JSON-schema-like shape used for function definitions.
type FunctionParameter struct {
	// Type is the parameter's type name: "string", "integer", "boolean".
	Type string `json:"type"`

	// Description is the human/model readable explanation of the parameter.
	Description string `json:"description"`

	// Enum optionally lists the allowed values; omitted from JSON when empty.
	Enum []string `json:"enum,omitempty"`
}
// FunctionParameters is the parameter schema attached to a tool definition:
// an object whose properties are the tool's named arguments.
type FunctionParameters struct {
	// Type is the schema type of the parameter set: "object".
	Type string `json:"type"`

	// Properties maps each argument name to its description.
	Properties map[string]FunctionParameter `json:"properties"`

	// Required holds the names of the arguments that must be supplied;
	// omitted from JSON when empty.
	Required []string `json:"required,omitempty"`
}
// AvailableTool pairs a tool definition with the Go function that carries
// it out.
type AvailableTool struct {
	// Tool is the embedded definition of the tool.
	openai.Tool

	// Impl executes the tool with the decoded call arguments. The result is
	// a plain string because tool call results are treated as normal
	// messages with string contents.
	Impl func(arguments map[string]any) (string, error)
}
// AvailableTools is the registry of tools, keyed by tool name. Each entry
// holds the function definition (name, description, parameter schema) and an
// Impl closure that validates the decoded arguments before delegating to
// ReadDir, ReadFile, WriteFile or ModifyFile.
//
// Every Impl returns an error only when an argument is missing or has the
// wrong type. The delegated functions report their own failures inside the
// returned JSON string.
var AvailableTools = map[string]AvailableTool{
	"read_dir": {
		Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
			Name: "read_dir",
			Description: `Return the contents of the CWD (current working directory).
Results are returned as JSON in the following format:
{
"message": "success", // if successful, or a different message indicating failure
// result may be an empty array if there are no files in the directory
"result": [
{"name": "a_file", "type": "file", "size": 123},
{"name": "a_directory/", "type": "dir", "size": 11},
... // more files or directories
]
}
For files, size represents the size (in bytes) of the file.
For directories, size represents the number of entries in that directory.`,
			Parameters: FunctionParameters{
				Type: "object",
				Properties: map[string]FunctionParameter{
					"relative_dir": {
						Type:        "string",
						Description: "If set, read the contents of a directory relative to the current one.",
					},
				},
			},
		}},
		Impl: func(args map[string]interface{}) (string, error) {
			// relative_dir is optional; the empty string is passed on when
			// it is absent.
			var relativeDir string
			if raw, ok := args["relative_dir"]; ok {
				relativeDir, ok = raw.(string)
				if !ok {
					return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", raw)
				}
			}
			return ReadDir(relativeDir), nil
		},
	},
	"read_file": {
		Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
			Name: "read_file",
			Description: `Read the contents of a text file relative to the current working directory.
Each line of the file is prefixed with its line number and a tabs (\t) to make
it make it easier to see which lines to change for other modifications.
Example:
{
"message": "success", // if successful, or a different message indicating failure
"result": "1\tthe contents\n2\tof the file\n"
}`,
			Parameters: FunctionParameters{
				Type: "object",
				Properties: map[string]FunctionParameter{
					"path": {
						Type:        "string",
						Description: "Path to a file within the current working directory to read.",
					},
				},
				Required: []string{"path"},
			},
		}},
		Impl: func(args map[string]interface{}) (string, error) {
			raw, ok := args["path"]
			if !ok {
				return "", fmt.Errorf("Path parameter to read_file was not included.")
			}
			path, ok := raw.(string)
			if !ok {
				return "", fmt.Errorf("Invalid path in function arguments: %v", raw)
			}
			return ReadFile(path), nil
		},
	},
	"write_file": {
		Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
			Name: "write_file",
			Description: `Write the provided contents to a file relative to the current working directory.
Result is returned as JSON in the following format:
{
"message": "success", // if successful, or a different message indicating failure
}`,
			Parameters: FunctionParameters{
				Type: "object",
				Properties: map[string]FunctionParameter{
					"path": {
						Type:        "string",
						Description: "Path to a file within the current working directory to write to.",
					},
					"content": {
						Type:        "string",
						Description: "The content to write to the file. Overwrites any existing content!",
					},
				},
				Required: []string{"path", "content"},
			},
		}},
		Impl: func(args map[string]interface{}) (string, error) {
			raw, ok := args["path"]
			if !ok {
				return "", fmt.Errorf("Path parameter to write_file was not included.")
			}
			path, ok := raw.(string)
			if !ok {
				return "", fmt.Errorf("Invalid path in function arguments: %v", raw)
			}

			raw, ok = args["content"]
			if !ok {
				return "", fmt.Errorf("Content parameter to write_file was not included.")
			}
			content, ok := raw.(string)
			if !ok {
				return "", fmt.Errorf("Invalid content in function arguments: %v", raw)
			}
			return WriteFile(path, content), nil
		},
	},
	"modify_file": {
		Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
			Name: "modify_file",
			Description: `Perform complex line-based modifications to a file.
Line ranges are inclusive. If 'start_line' is specified but 'end_line' is not,
'end_line' gets set to the last line of the file.
To replace or remove a single line, *set start_line and end_line to the same value*
Examples:
* Insert the lines "hello<new line>world" at line 10, preserving other content:
{"path": "myfile", "operation": "insert", "start_line": 10, "content": "hello\nworld"}
* Remove lines 45 up to and including 54:
{"path": "myfile", "operation": "remove", "start_line": 45, "end_line": 54}
* Replace content from line 10 to 25:
{"path": "myfile", "operation": "replace", "start_line": 10, "end_line": 25, "content": "i\nwas\nhere"}
* Replace contents of entire the file:
{"path": "myfile", "operation": "replace", "start_line": 0, "content": "i\nwas\nhere"}`,
			Parameters: FunctionParameters{
				Type: "object",
				Properties: map[string]FunctionParameter{
					"path": {
						Type:        "string",
						Description: "Path of the file to be modified, relative to the current working directory.",
					},
					"operation": {
						Type:        "string",
						Description: `The the type of modification to make to the file. One of: insert, remove, replace`,
					},
					"start_line": {
						Type:        "integer",
						Description: `(Optional) Where to start making a modification (insert, remove, and replace).`,
					},
					"end_line": {
						Type:        "integer",
						Description: `(Optional) Where to stop making a modification (remove or replace, end of file if omitted).`,
					},
					"content": {
						Type:        "string",
						Description: `(Optional) The content to insert, or replace with.`,
					},
				},
				Required: []string{"path", "operation"},
			},
		}},
		Impl: func(args map[string]interface{}) (string, error) {
			raw, ok := args["path"]
			if !ok {
				return "", fmt.Errorf("path parameter to modify_file was not included.")
			}
			path, ok := raw.(string)
			if !ok {
				return "", fmt.Errorf("Invalid path in function arguments: %v", raw)
			}

			raw, ok = args["operation"]
			if !ok {
				return "", fmt.Errorf("operation parameter to modify_file was not included.")
			}
			operation, ok := raw.(string)
			if !ok {
				return "", fmt.Errorf("Invalid operation in function arguments: %v", raw)
			}

			// Numeric arguments are asserted as float64 and truncated to int.
			// The raw value is kept in its own variable so that a failed
			// assertion reports what was actually supplied.
			var startLine int
			if raw, ok := args["start_line"]; ok {
				num, ok := raw.(float64)
				if !ok {
					return "", fmt.Errorf("Invalid start_line in function arguments: %v", raw)
				}
				startLine = int(num)
			}

			var endLine int
			if raw, ok := args["end_line"]; ok {
				num, ok := raw.(float64)
				if !ok {
					return "", fmt.Errorf("Invalid end_line in function arguments: %v", raw)
				}
				endLine = int(num)
			}

			// content is optional (e.g. for "remove").
			var content string
			if raw, ok := args["content"]; ok {
				content, ok = raw.(string)
				if !ok {
					return "", fmt.Errorf("Invalid content in function arguments: %v", raw)
				}
			}

			return ModifyFile(path, operation, content, startLine, endLine), nil
		},
	},
}
// resultToJson serializes result to its JSON string form.
//
// An empty Message is replaced with "success" before encoding. If result
// cannot be encoded (for example because Result holds a value encoding/json
// does not support), the failure is logged to stderr and a JSON envelope
// describing the failure is returned instead, so the caller always receives
// a well-formed, non-empty result.
func resultToJson(result FunctionResult) string {
	if result.Message == "" {
		// When message not supplied, assume success
		result.Message = "success"
	}

	jsonBytes, err := json.Marshal(result)
	if err != nil {
		failure := fmt.Sprintf("Could not marshal FunctionResult to JSON: %v", err)
		fmt.Fprintln(os.Stderr, failure)
		// A FunctionResult carrying only a string message always encodes,
		// so this error can be safely ignored.
		fallback, _ := json.Marshal(FunctionResult{Message: failure})
		return string(fallback)
	}
	return string(jsonBytes)
}
// ExecuteToolCalls runs each of the provided tool calls in order and returns
// one Message with role "tool" per executed call, carrying the tool's output
// and the ID of the call it answers.
//
// Calls whose type is not "function" are skipped. The first failure (unknown
// tool name, arguments that are not valid JSON, or an error from the tool's
// Impl) aborts processing and is returned with a nil slice.
func ExecuteToolCalls(toolCalls []openai.ToolCall) ([]Message, error) {
	var results []Message

	for _, call := range toolCalls {
		if call.Type != "function" {
			// unsupported tool type
			continue
		}

		tool, known := AvailableTools[call.Function.Name]
		if !known {
			return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", call.Function.Name)
		}

		// Arguments arrive as a JSON document; decode them generically.
		var args map[string]interface{}
		if err := json.Unmarshal([]byte(call.Function.Arguments), &args); err != nil {
			return nil, fmt.Errorf("Could not unmarshal tool arguments. Malformed JSON? Error: %v", err)
		}

		// TODO: ability to silence this
		fmt.Fprintf(os.Stderr, "INFO: Executing tool '%s' with args %s\n", call.Function.Name, call.Function.Arguments)

		output, err := tool.Impl(args)
		if err != nil {
			// This can happen if the model missed or supplied invalid tool args
			return nil, fmt.Errorf("Tool '%s' error: %v\n", call.Function.Name, err)
		}

		// No name is attached to the message: it is not required since the
		// introduction of ToolCallID. Untested idea: setting it might tell
		// the model what a function was for when later requests omit the
		// function definition.
		results = append(results, Message{
			Role:            "tool",
			OriginalContent: output,
			ToolCallID:      sql.NullString{String: call.ID, Valid: true},
		})
	}

	return results, nil
}
// isPathContained reports whether `path` is the same as, or located inside,
// `directory`.
//
// Both arguments are made absolute and have their symlinks resolved before
// being compared. `path` does not have to exist: in that case the deepest
// existing ancestor is resolved and the missing components are appended to
// it, so a new file below a symlinked directory is judged by where it would
// really be created. A path that ends in a dangling symlink is rejected with
// an error, since its eventual target cannot be verified.
//
// The comparison itself ignores case.
// NOTE(review): on a case-sensitive filesystem this accepts a path that
// differs from `directory` only by case — confirm this is intended.
//
// This is ultimately to attempt to stop an LLM from going outside of where I
// tell it to. Additional layers of security should be considered.. run in a
// VM/container.
func isPathContained(directory string, path string) (bool, error) {
	path, err := filepath.Abs(path)
	if err != nil {
		return false, err
	}
	path, err = resolveExistingAncestor(path)
	if err != nil {
		return false, err
	}

	directory, err = filepath.Abs(directory)
	if err != nil {
		return false, err
	}
	directory, err = filepath.EvalSymlinks(directory)
	if err != nil {
		return false, err
	}

	// Case insensitive checks
	if strings.EqualFold(path, directory) {
		return true, nil
	}
	prefix := strings.ToLower(directory) + string(os.PathSeparator)
	return strings.HasPrefix(strings.ToLower(path), prefix), nil
}

// resolveExistingAncestor resolves symlinks in the longest prefix of the
// absolute `path` that exists on disk and re-attaches the components that do
// not exist yet. It returns an error if a component is a dangling symlink or
// if the filesystem cannot be inspected.
func resolveExistingAncestor(path string) (string, error) {
	// Components stripped from the end of the path, innermost first.
	var missing []string

	current := path
	for {
		resolved, err := filepath.EvalSymlinks(current)
		if err == nil {
			for i := len(missing) - 1; i >= 0; i-- {
				resolved = filepath.Join(resolved, missing[i])
			}
			return resolved, nil
		}
		if !os.IsNotExist(err) {
			return "", fmt.Errorf("Could not stat path: %v", err)
		}
		// The entry itself is present but cannot be resolved: it is a
		// symlink whose target is missing. Writing through it would create
		// the target wherever the link points, so refuse it.
		if _, lerr := os.Lstat(current); lerr == nil {
			return "", fmt.Errorf("Path contains a dangling symlink: %s", current)
		}

		parent := filepath.Dir(current)
		if parent == current {
			// Reached the filesystem root without finding anything that
			// exists; fall back to the unresolved path.
			return path, nil
		}
		missing = append(missing, filepath.Base(current))
		current = parent
	}
}
// isPathWithinCWD checks `path` against the process's current working
// directory using isPathContained.
//
// It returns (true, nil) when the path is accepted. Otherwise it returns
// false together with a FunctionResult whose Message explains the refusal:
// the working directory could not be determined, the containment check
// itself failed, or the path lies outside the working directory.
func isPathWithinCWD(path string) (bool, *FunctionResult) {
	cwd, err := os.Getwd()
	if err != nil {
		return false, &FunctionResult{Message: "Failed to determine current working directory"}
	}

	contained, err := isPathContained(cwd, path)
	switch {
	case contained:
		return true, nil
	case err != nil:
		return false, &FunctionResult{Message: fmt.Sprintf("Could not determine whether path '%s' is within the current working directory: %s", path, err.Error())}
	default:
		return false, &FunctionResult{Message: fmt.Sprintf("Path '%s' is not within the current working directory", path)}
	}
}
// ReadDir lists the non-hidden entries of the directory at `path` and returns
// them as a JSON-encoded FunctionResult.
//
// An empty `path` means the current directory. The path is first checked with
// isPathWithinCWD; a refusal, or a failure to read the directory, is reported
// through the result's message.
//
// On success the result is an array (empty, never null, when nothing is
// listed) of objects with:
//   - "name": the entry name, with a trailing "/" for directories
//   - "type": "file" or "dir"
//   - "size": the size in bytes for files, the number of entries for
//     directories
//
// Entries whose name starts with "." are skipped, as are entries whose
// metadata cannot be read.
func ReadDir(path string) string {
	// TODO(?): implement whitelist - list of directories which model is allowed to work in
	if path == "" {
		path = "."
	}
	if ok, res := isPathWithinCWD(path); !ok {
		return resultToJson(*res)
	}

	entries, err := os.ReadDir(path)
	if err != nil {
		return resultToJson(FunctionResult{
			Message: err.Error(),
		})
	}

	// Allocated non-nil so that an empty listing encodes as [] rather than
	// null.
	dirContents := make([]map[string]interface{}, 0, len(entries))
	for _, entry := range entries {
		name := entry.Name()
		if strings.HasPrefix(name, ".") {
			// skip hidden files
			continue
		}

		info, err := entry.Info()
		if err != nil {
			// The entry's metadata is unavailable (it may have been removed
			// since the directory was read); leave it out of the listing.
			continue
		}

		entryType := "file"
		size := info.Size()
		if info.IsDir() {
			entryType = "dir"
			// Best effort: an unreadable subdirectory is reported with a
			// size of 0 instead of failing the whole listing.
			children, _ := os.ReadDir(filepath.Join(path, name))
			size = int64(len(children))
			name += "/"
		}

		dirContents = append(dirContents, map[string]interface{}{
			"name": name,
			"type": entryType,
			"size": size,
		})
	}

	return resultToJson(FunctionResult{Result: dirContents})
}
// ReadFile returns the contents of the file at `path` as a JSON-encoded
// FunctionResult, with every line prefixed by its 1-based line number and a
// tab, and terminated by a newline.
//
// The path is first checked with isPathWithinCWD; a refusal, or a failure to
// read the file, is reported through the result's message.
func ReadFile(path string) string {
	if ok, res := isPathWithinCWD(path); !ok {
		return resultToJson(*res)
	}

	raw, err := os.ReadFile(path)
	if err != nil {
		return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
	}

	// The file is split on "\n"; text after the final newline (possibly
	// empty) counts as a line of its own.
	var numbered strings.Builder
	for idx, text := range strings.Split(string(raw), "\n") {
		fmt.Fprintf(&numbered, "%d\t%s\n", idx+1, text)
	}

	return resultToJson(FunctionResult{Result: numbered.String()})
}
// WriteFile writes `content` to the file at `path` with permissions 0644 and
// returns a JSON-encoded FunctionResult.
//
// The path is first checked with isPathWithinCWD; a refusal, or a failure to
// write, is reported through the result's message. On success the result
// carries no payload.
func WriteFile(path string, content string) string {
	if ok, res := isPathWithinCWD(path); !ok {
		return resultToJson(*res)
	}

	if err := os.WriteFile(path, []byte(content), 0644); err != nil {
		return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
	}

	return resultToJson(FunctionResult{})
}
// ModifyFile applies a line-based edit to the file at `path` and returns a
// JSON-encoded FunctionResult whose payload is the file's new content.
//
// Parameters:
//   - operation: "insert", "remove" or "replace".
//   - content: the text to insert or to replace with; leading and trailing
//     newlines are trimmed. Unused by "remove".
//   - startLine: 1-based first line affected. 0 is treated as 1. It may be at
//     most one past the last line (which appends).
//   - endLine: 1-based last line affected (inclusive) for "remove" and
//     "replace". 0, or a value past the end, means the last line of the file.
//     Unused by "insert".
//
// Lines are the pieces obtained by splitting the file on "\n". A file that
// does not exist is treated as empty and is only created once the edit has
// been validated and is written out.
//
// The path is first checked with isPathWithinCWD. A refusal, an unknown
// operation, an out-of-range line number, or an I/O failure is reported
// through the result's message, and the file is left untouched.
func ModifyFile(path string, operation string, content string, startLine int, endLine int) string {
	if ok, res := isPathWithinCWD(path); !ok {
		return resultToJson(*res)
	}

	// Read the existing file's content
	data, err := os.ReadFile(path)
	if err != nil {
		if !os.IsNotExist(err) {
			return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
		}
		// Start from empty content; the final write creates the file.
		data = nil
	}

	if startLine < 0 {
		return resultToJson(FunctionResult{Message: "start_line cannot be less than 0"})
	}
	switch operation {
	case "insert", "remove", "replace":
	default:
		return resultToJson(FunctionResult{Message: fmt.Sprintf("Invalid operation: %s", operation)})
	}

	// Split the content by newline to process lines
	lines := strings.Split(string(data), "\n")

	if startLine == 0 {
		// start_line is optional and defaults to 0; line numbers start at 1
		startLine = 1
	}
	if startLine > len(lines)+1 {
		return resultToJson(FunctionResult{Message: fmt.Sprintf("start_line %d is beyond the end of the file (%d lines)", startLine, len(lines))})
	}

	// The edit replaces lines[from:to] with `replacement`.
	from := startLine - 1
	to := from // insert: nothing is removed
	if operation != "insert" {
		if endLine == 0 || endLine > len(lines) {
			endLine = len(lines)
		}
		if endLine < from {
			return resultToJson(FunctionResult{Message: "end_line cannot be less than start_line"})
		}
		to = endLine
	}

	var replacement []string
	if operation != "remove" {
		replacement = strings.Split(strings.Trim(content, "\n"), "\n")
	}

	// Build the result in a fresh slice so that no append can overwrite
	// lines that still have to be copied.
	updated := make([]string, 0, from+len(replacement)+len(lines)-to)
	updated = append(updated, lines[:from]...)
	updated = append(updated, replacement...)
	updated = append(updated, lines[to:]...)

	// Join the lines and write back to the file
	newContent := strings.Join(updated, "\n")
	err = os.WriteFile(path, []byte(newContent), 0644)
	if err != nil {
		return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
	}

	return resultToJson(FunctionResult{Result: newContent})
}