diff --git a/pkg/cli/functions.go b/pkg/cli/functions.go new file mode 100644 index 0000000..ea8a9c3 --- /dev/null +++ b/pkg/cli/functions.go @@ -0,0 +1,582 @@ +package cli + +import ( + "database/sql" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + + openai "github.com/sashabaranov/go-openai" +) + +type FunctionResult struct { + Message string `json:"message"` + Result any `json:"result,omitempty"` +} + +type FunctionParameter struct { + Type string `json:"type"` // "string", "integer", "boolean" + Description string `json:"description"` + Enum []string `json:"enum,omitempty"` +} + +type FunctionParameters struct { + Type string `json:"type"` // "object" + Properties map[string]FunctionParameter `json:"properties"` + Required []string `json:"required,omitempty"` // required function parameter names +} + +type AvailableTool struct { + openai.Tool + // The tool's implementation. Returns a string, as tool call results + // are treated as normal messages with string contents. + Impl func(arguments map[string]interface{}) (string, error) +} + +const ( + READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory). + +Results are returned as JSON in the following format: +{ + "message": "success", // if successful, or a different message indicating failure + // result may be an empty array if there are no files in the directory + "result": [ + {"name": "a_file", "type": "file", "size": 123}, + {"name": "a_directory/", "type": "dir", "size": 11}, + ... // more files or directories + ] +} + +For files, size represents the size (in bytes) of the file. +For directories, size represents the number of entries in that directory.` + + READ_FILE_DESCRIPTION = `Read the contents of a text file relative to the current working directory. + +Each line of the file is prefixed with its line number and a tabs (\t) to make +it make it easier to see which lines to change for other modifications. + +Example result: +{ + "message": "success", // if successful, or a different message indicating failure + "result": "1\tthe contents\n2\tof the file\n" +}` + + WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory. + +Note: only use this tool when you've been explicitly asked to create or write to a file. + +When using this function, you do not need to share the content you intend to write with the user first. + +Example result: +{ + "message": "success", // if successful, or a different message indicating failure +}` + + FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path. + +Make sure your inserts match the flow and indentation of surrounding content.` + + FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path. + +Useful for re-writing snippets/blocks of code or entire functions. + +Be cautious with your edits. When replacing, ensure the replacement content matches the flow and indentation of surrounding content.` +) + +var AvailableTools = map[string]AvailableTool{ + "read_dir": { + Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{ + Name: "read_dir", + Description: READ_DIR_DESCRIPTION, + Parameters: FunctionParameters{ + Type: "object", + Properties: map[string]FunctionParameter{ + "relative_dir": { + Type: "string", + Description: "If set, read the contents of a directory relative to the current one.", + }, + }, + }, + }}, + Impl: func(args map[string]interface{}) (string, error) { + var relativeDir string + tmp, ok := args["relative_dir"] + if ok { + relativeDir, ok = tmp.(string) + if !ok { + return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp) + } + } + return ReadDir(relativeDir), nil + }, + }, + "read_file": { + Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{ + Name: "read_file", + Description: READ_FILE_DESCRIPTION, + Parameters: FunctionParameters{ + Type: "object", + Properties: map[string]FunctionParameter{ + "path": { + Type: "string", + Description: "Path to a file within the current working directory to read.", + }, + }, + Required: []string{"path"}, + }, + }}, + Impl: func(args map[string]interface{}) (string, error) { + tmp, ok := args["path"] + if !ok { + return "", fmt.Errorf("Path parameter to read_file was not included.") + } + path, ok := tmp.(string) + if !ok { + return "", fmt.Errorf("Invalid path in function arguments: %v", tmp) + } + return ReadFile(path), nil + }, + }, + "write_file": { + Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{ + Name: "write_file", + Description: WRITE_FILE_DESCRIPTION, + Parameters: FunctionParameters{ + Type: "object", + Properties: map[string]FunctionParameter{ + "path": { + Type: "string", + Description: "Path to a file within the current working directory to write to.", + }, + "content": { + Type: "string", + Description: "The content to write to the file. Overwrites any existing content!", + }, + }, + Required: []string{"path", "content"}, + }, + }}, + Impl: func(args map[string]interface{}) (string, error) { + tmp, ok := args["path"] + if !ok { + return "", fmt.Errorf("Path parameter to write_file was not included.") + } + path, ok := tmp.(string) + if !ok { + return "", fmt.Errorf("Invalid path in function arguments: %v", tmp) + } + tmp, ok = args["content"] + if !ok { + return "", fmt.Errorf("Content parameter to write_file was not included.") + } + content, ok := tmp.(string) + if !ok { + return "", fmt.Errorf("Invalid content in function arguments: %v", tmp) + } + return WriteFile(path, content), nil + }, + }, + "file_insert_lines": { + Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{ + Name: "file_insert_lines", + Description: FILE_INSERT_LINES_DESCRIPTION, + Parameters: FunctionParameters{ + Type: "object", + Properties: map[string]FunctionParameter{ + "path": { + Type: "string", + Description: "Path of the file to be modified, relative to the current working directory.", + }, + "position": { + Type: "integer", + Description: `Which line to insert content *before*.`, + }, + "content": { + Type: "string", + Description: `The content to insert.`, + }, + }, + Required: []string{"path", "position", "content"}, + }, + }}, + Impl: func(args map[string]interface{}) (string, error) { + tmp, ok := args["path"] + if !ok { + return "", fmt.Errorf("path parameter to write_file was not included.") + } + path, ok := tmp.(string) + if !ok { + return "", fmt.Errorf("Invalid path in function arguments: %v", tmp) + } + var position int + tmp, ok = args["position"] + if ok { + tmp, ok := tmp.(float64) + if !ok { + return "", fmt.Errorf("Invalid position in function arguments: %v", tmp) + } + position = int(tmp) + } + var content string + tmp, ok = args["content"] + if ok { + content, ok = tmp.(string) + if !ok { + return "", fmt.Errorf("Invalid content in function arguments: %v", tmp) + } + } + return FileInsertLines(path, position, content), nil + }, + }, + "file_replace_lines": { + Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{ + Name: "file_replace_lines", + Description: FILE_REPLACE_LINES_DESCRIPTION, + Parameters: FunctionParameters{ + Type: "object", + Properties: map[string]FunctionParameter{ + "path": { + Type: "string", + Description: "Path of the file to be modified, relative to the current working directory.", + }, + "start_line": { + Type: "integer", + Description: `Line number which specifies the start of the replacement range (inclusive).`, + }, + "end_line": { + Type: "integer", + Description: `Line number which specifies the end of the replacement range (inclusive). If unset, range extends to end of file.`, + }, + "content": { + Type: "string", + Description: `Content to replace specified range. Omit to remove the specified range.`, + }, + }, + Required: []string{"path", "start_line"}, + }, + }}, + Impl: func(args map[string]interface{}) (string, error) { + tmp, ok := args["path"] + if !ok { + return "", fmt.Errorf("path parameter to write_file was not included.") + } + path, ok := tmp.(string) + if !ok { + return "", fmt.Errorf("Invalid path in function arguments: %v", tmp) + } + var start_line int + tmp, ok = args["start_line"] + if ok { + tmp, ok := tmp.(float64) + if !ok { + return "", fmt.Errorf("Invalid start_line in function arguments: %v", tmp) + } + start_line = int(tmp) + } + var end_line int + tmp, ok = args["end_line"] + if ok { + tmp, ok := tmp.(float64) + if !ok { + return "", fmt.Errorf("Invalid end_line in function arguments: %v", tmp) + } + end_line = int(tmp) + } + var content string + tmp, ok = args["content"] + if ok { + content, ok = tmp.(string) + if !ok { + return "", fmt.Errorf("Invalid content in function arguments: %v", tmp) + } + } + + return FileReplaceLines(path, start_line, end_line, content), nil + }, + }, +} + +func resultToJson(result FunctionResult) string { + if result.Message == "" { + // When message not supplied, assume success + result.Message = "success" + } + + jsonBytes, err := json.Marshal(result) + if err != nil { + fmt.Printf("Could not marshal FunctionResult to JSON: %v\n", err) + } + return string(jsonBytes) +} + +// ExecuteToolCalls handles the execution of all tool_calls provided, and +// returns their results formatted as []Message(s) with role: 'tool' and. +func ExecuteToolCalls(toolCalls []openai.ToolCall) ([]Message, error) { + var toolResults []Message + for _, toolCall := range toolCalls { + if toolCall.Type != "function" { + // unsupported tool type + continue + } + + tool, ok := AvailableTools[toolCall.Function.Name] + if !ok { + return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Function.Name) + } + + var functionArgs map[string]interface{} + err := json.Unmarshal([]byte(toolCall.Function.Arguments), &functionArgs) + if err != nil { + return nil, fmt.Errorf("Could not unmarshal tool arguments. Malformed JSON? Error: %v", err) + } + + // TODO: ability to silence this + fmt.Fprintf(os.Stderr, "INFO: Executing tool '%s' with args %s\n", toolCall.Function.Name, toolCall.Function.Arguments) + + // Execute the tool + toolResult, err := tool.Impl(functionArgs) + if err != nil { + // This can happen if the model missed or supplied invalid tool args + return nil, fmt.Errorf("Tool '%s' error: %v\n", toolCall.Function.Name, err) + } + + toolResults = append(toolResults, Message{ + Role: "tool", + OriginalContent: toolResult, + ToolCallID: sql.NullString{String: toolCall.ID, Valid: true}, + // name is not required since the introduction of ToolCallID + // hypothesis: by setting it, we inform the model of what a + // function's purpose was if future requests omit the function + // definition + }) + } + return toolResults, nil +} + +// isPathContained attempts to verify whether `path` is the same as or +// contained within `directory`. It is overly cautious, returning false even if +// `path` IS contained within `directory`, but the two paths use different +// casing, and we happen to be on a case-insensitive filesystem. +// This is ultimately to attempt to stop an LLM from going outside of where I +// tell it to. Additional layers of security should be considered.. run in a +// VM/container. +func isPathContained(directory string, path string) (bool, error) { + // Clean and resolve symlinks for both paths + path, err := filepath.Abs(path) + if err != nil { + return false, err + } + + // check if path exists + _, err = os.Stat(path) + if err != nil { + if !os.IsNotExist(err) { + return false, fmt.Errorf("Could not stat path: %v", err) + } + } else { + path, err = filepath.EvalSymlinks(path) + if err != nil { + return false, err + } + } + + directory, err = filepath.Abs(directory) + if err != nil { + return false, err + } + directory, err = filepath.EvalSymlinks(directory) + if err != nil { + return false, err + } + + // Case insensitive checks + if !strings.EqualFold(path, directory) && + !strings.HasPrefix(strings.ToLower(path), strings.ToLower(directory)+string(os.PathSeparator)) { + return false, nil + } + + return true, nil +} + +func isPathWithinCWD(path string) (bool, *FunctionResult) { + cwd, err := os.Getwd() + if err != nil { + return false, &FunctionResult{Message: "Failed to determine current working directory"} + } + if ok, err := isPathContained(cwd, path); !ok { + if err != nil { + return false, &FunctionResult{Message: fmt.Sprintf("Could not determine whether path '%s' is within the current working directory: %s", path, err.Error())} + } + return false, &FunctionResult{Message: fmt.Sprintf("Path '%s' is not within the current working directory", path)} + } + return true, nil +} + +func ReadDir(path string) string { + // TODO(?): implement whitelist - list of directories which model is allowed to work in + if path == "" { + path = "." + } + ok, res := isPathWithinCWD(path) + if !ok { + return resultToJson(*res) + } + + files, err := os.ReadDir(path) + if err != nil { + return resultToJson(FunctionResult{ + Message: err.Error(), + }) + } + + var dirContents []map[string]interface{} + for _, f := range files { + info, _ := f.Info() + + name := f.Name() + if strings.HasPrefix(name, ".") { + // skip hidden files + continue + } + + entryType := "file" + size := info.Size() + + if info.IsDir() { + name += "/" + entryType = "dir" + subdirfiles, _ := os.ReadDir(filepath.Join(".", path, info.Name())) + size = int64(len(subdirfiles)) + } + + dirContents = append(dirContents, map[string]interface{}{ + "name": name, + "type": entryType, + "size": size, + }) + } + + return resultToJson(FunctionResult{Result: dirContents}) +} + +func ReadFile(path string) string { + ok, res := isPathWithinCWD(path) + if !ok { + return resultToJson(*res) + } + data, err := os.ReadFile(path) + if err != nil { + return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}) + } + + lines := strings.Split(string(data), "\n") + content := strings.Builder{} + for i, line := range lines { + content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line)) + } + + return resultToJson(FunctionResult{ + Result: content.String(), + }) +} + +func WriteFile(path string, content string) string { + ok, res := isPathWithinCWD(path) + if !ok { + return resultToJson(*res) + } + err := os.WriteFile(path, []byte(content), 0644) + if err != nil { + return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}) + } + return resultToJson(FunctionResult{}) +} + +func FileInsertLines(path string, position int, content string) string { + ok, res := isPathWithinCWD(path) + if !ok { + return resultToJson(*res) + } + + // Read the existing file's content + data, err := os.ReadFile(path) + if err != nil { + if !os.IsNotExist(err) { + return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}) + } + _, err = os.Create(path) + if err != nil { + return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}) + } + data = []byte{} + } + + if position < 1 { + return resultToJson(FunctionResult{Message: "start_line cannot be less than 1"}) + } + + lines := strings.Split(string(data), "\n") + contentLines := strings.Split(strings.Trim(content, "\n"), "\n") + + before := lines[:position-1] + after := lines[position-1:] + lines = append(before, append(contentLines, after...)...) + + newContent := strings.Join(lines, "\n") + + // Join the lines and write back to the file + err = os.WriteFile(path, []byte(newContent), 0644) + if err != nil { + return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}) + } + + return resultToJson(FunctionResult{Result: newContent}) +} + +func FileReplaceLines(path string, startLine int, endLine int, content string) string { + ok, res := isPathWithinCWD(path) + if !ok { + return resultToJson(*res) + } + + // Read the existing file's content + data, err := os.ReadFile(path) + if err != nil { + if !os.IsNotExist(err) { + return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}) + } + _, err = os.Create(path) + if err != nil { + return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}) + } + data = []byte{} + } + + if startLine < 1 { + return resultToJson(FunctionResult{Message: "start_line cannot be less than 1"}) + } + + lines := strings.Split(string(data), "\n") + contentLines := strings.Split(strings.Trim(content, "\n"), "\n") + + if endLine == 0 || endLine > len(lines) { + endLine = len(lines) + } + + before := lines[:startLine-1] + after := lines[endLine:] + + lines = append(before, append(contentLines, after...)...) + newContent := strings.Join(lines, "\n") + + // Join the lines and write back to the file + err = os.WriteFile(path, []byte(newContent), 0644) + if err != nil { + return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}) + } + + return resultToJson(FunctionResult{Result: newContent}) + +} diff --git a/pkg/cli/openai.go b/pkg/cli/openai.go index 0cc1196..c03cf30 100644 --- a/pkg/cli/openai.go +++ b/pkg/cli/openai.go @@ -2,7 +2,10 @@ package cli import ( "context" + "database/sql" + "encoding/json" "errors" + "fmt" "io" "strings" @@ -12,22 +15,44 @@ import ( func CreateChatCompletionRequest(model string, messages []Message, maxTokens int) openai.ChatCompletionRequest { chatCompletionMessages := []openai.ChatCompletionMessage{} for _, m := range messages { - chatCompletionMessages = append(chatCompletionMessages, openai.ChatCompletionMessage{ + message := openai.ChatCompletionMessage{ Role: m.Role, Content: m.OriginalContent, - }) + } + if m.ToolCallID.Valid { + message.ToolCallID = m.ToolCallID.String + } + if m.ToolCalls.Valid { + // unmarshal directly into chatMessage.ToolCalls + err := json.Unmarshal([]byte(m.ToolCalls.String), &message.ToolCalls) + if err != nil { + // TODO: handle, this shouldn't really happen since + // we only save the successfully marshal'd data to database + fmt.Printf("Error unmarshalling the tool_calls JSON: %v\n", err) + } + } + chatCompletionMessages = append(chatCompletionMessages, message) + } + + var tools []openai.Tool + for _, t := range AvailableTools { + // TODO: support some way to limit which tools are available per-request + tools = append(tools, t.Tool) } return openai.ChatCompletionRequest{ - Model: model, - Messages: chatCompletionMessages, - MaxTokens: maxTokens, - N: 1, // limit responses to 1 "choice". we use choices[0] to reference it + Model: model, + Messages: chatCompletionMessages, + MaxTokens: maxTokens, + N: 1, // limit responses to 1 "choice". we use choices[0] to reference it + Tools: tools, + ToolChoice: "auto", // TODO: allow limiting/forcing which function is called? } } // CreateChatCompletion submits a Chat Completion API request and returns the -// response. +// response. CreateChatCompletion will recursively call itself in the case of +// tool calls, until a response is received with the final user-facing output. func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) { client := openai.NewClient(*config.OpenAI.APIKey) req := CreateChatCompletionRequest(model, messages, maxTokens) @@ -36,7 +61,32 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri return "", err } - return resp.Choices[0].Message.Content, nil + choice := resp.Choices[0] + + if len(choice.Message.ToolCalls) > 0 { + if choice.Message.Content != "" { + return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.") + } + + // Append the assistant's reply with its request for tool calls + toolCallJson, _ := json.Marshal(choice.Message.ToolCalls) + messages = append(messages, Message{ + Role: "assistant", + ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true}, + }) + + toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls) + if err != nil { + return "", err + } + + // Recurse into CreateChatCompletion with the tool call replies added + // to the original messages + return CreateChatCompletion(model, append(messages, toolReplies...), maxTokens) + } + + // Return the user-facing message. + return choice.Message.Content, nil } // CreateChatCompletionStream submits a streaming Chat Completion API request @@ -52,7 +102,10 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int, } defer stream.Close() - sb := strings.Builder{} + content := strings.Builder{} + toolCalls := []openai.ToolCall{} + + // Iterate stream segments for { response, e := stream.Recv() if errors.Is(e, io.EOF) { @@ -63,9 +116,47 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int, err = e break } - chunk := response.Choices[0].Delta.Content - output <- chunk - sb.WriteString(chunk) + + delta := response.Choices[0].Delta + if len(delta.ToolCalls) > 0 { + // Construct streamed tool_call arguments + for _, tc := range delta.ToolCalls { + if tc.Index == nil { + return "", fmt.Errorf("Unexpected nil index for streamed tool call.") + } + if len(toolCalls) <= *tc.Index { + toolCalls = append(toolCalls, tc) + } else { + toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments + } + } + } else { + output <- delta.Content + content.WriteString(delta.Content) + } } - return sb.String(), err + + if len(toolCalls) > 0 { + if content.String() != "" { + return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.") + } + + // Append the assistant's reply with its request for tool calls + toolCallJson, _ := json.Marshal(toolCalls) + messages = append(messages, Message{ + Role: "assistant", + ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true}, + }) + + toolReplies, err := ExecuteToolCalls(toolCalls) + if err != nil { + return "", err + } + + // Recurse into CreateChatCompletionStream with the tool call replies + // added to the original messages + return CreateChatCompletionStream(model, append(messages, toolReplies...), maxTokens, output) + } + + return content.String(), err } diff --git a/pkg/cli/store.go b/pkg/cli/store.go index 6518a8d..2ed7909 100644 --- a/pkg/cli/store.go +++ b/pkg/cli/store.go @@ -23,8 +23,10 @@ type Message struct { ConversationID uint `gorm:"foreignKey:ConversationID"` Conversation Conversation OriginalContent string - Role string // 'user' or 'assistant' + Role string // one of: 'user', 'assistant', 'tool' CreatedAt time.Time + ToolCallID sql.NullString + ToolCalls sql.NullString // a json-encoded array of tool calls from the model } type Conversation struct {