diff --git a/pkg/cli/functions.go b/pkg/cli/functions.go
new file mode 100644
index 0000000..b49afb1
--- /dev/null
+++ b/pkg/cli/functions.go
@@ -0,0 +1,168 @@
+package cli
+
+import (
+	"database/sql"
+	"encoding/json"
+	"fmt"
+	"os"
+	"path/filepath"
+
+	openai "github.com/sashabaranov/go-openai"
+)
+
+type FunctionResult struct {
+	Message string `json:"message"`
+	Result  any    `json:"result,omitempty"`
+}
+
+type FunctionParameter struct {
+	Type        string   `json:"type"` // "string", "integer", "boolean"
+	Description string   `json:"description"`
+	Enum        []string `json:"enum,omitempty"`
+}
+
+type FunctionParameters struct {
+	Type       string                       `json:"type"` // "object"
+	Properties map[string]FunctionParameter `json:"properties"`
+	Required   []string                     `json:"required,omitempty"` // required function parameter names
+}
+
+type AvailableTool struct {
+	openai.Tool
+	// The tool's implementation. Returns a string, as tool call results
+	// are treated as normal messages with string contents.
+	Impl func(arguments map[string]interface{}) (string, error)
+}
+
+var AvailableTools = map[string]AvailableTool{
+	"read_dir": {
+		Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
+			Name: "read_dir",
+			Description: `Return the contents of the CWD (current working directory).
+
+Results are returned as JSON in the following format:
+{
+  "message": "success", // "success" if successful, or a different message indicating failure
+  "result": [
+    {"name": "a_file", "type": "file", "length": 123},
+    {"name": "a_directory", "type": "dir", "length": 5},
+    ... // more files or directories
+  ]
+}
+
+For type: file, length represents the size (in bytes) of the file.
+For type: dir, length represents the number of entries in that directory.`,
+			Parameters: FunctionParameters{
+				Type: "object",
+				Properties: map[string]FunctionParameter{
+					"relative_dir": {
+						Type:        "string",
+						Description: "If set, read the contents of a directory relative to the current one.",
+					},
+				},
+			},
+		}},
+		Impl: func(args map[string]interface{}) (string, error) {
+			var relativeDir string
+			tmp, ok := args["relative_dir"]
+			if ok {
+				relativeDir, ok = tmp.(string)
+				if !ok {
+					return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp)
+				}
+			}
+			return ReadDir(relativeDir), nil
+		},
+	},
+}
+
+func resultToJson(result FunctionResult) string {
+	if result.Message == "" {
+		// When message not supplied, assume success
+		result.Message = "success"
+	}
+
+	jsonBytes, err := json.Marshal(result)
+	if err != nil {
+		fmt.Printf("Could not marshal FunctionResult to JSON: %v\n", err)
+	}
+	return string(jsonBytes)
+}
+
+// ExecuteToolCalls handles the execution of all tool_calls provided, and
+// returns their results formatted as []Message(s) with role: 'tool'.
+func ExecuteToolCalls(toolCalls []openai.ToolCall) ([]Message, error) {
+	var toolResults []Message
+	for _, toolCall := range toolCalls {
+		if toolCall.Type != "function" {
+			// unsupported tool type
+			continue
+		}
+
+		tool, ok := AvailableTools[toolCall.Function.Name]
+		if !ok {
+			return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Function.Name)
+		}
+
+		var functionArgs map[string]interface{}
+		err := json.Unmarshal([]byte(toolCall.Function.Arguments), &functionArgs)
+		if err != nil {
+			return nil, fmt.Errorf("Could not unmarshal tool arguments. Malformed JSON? Error: %v", err)
+		}
+
+		// TODO: ability to silence this
+		fmt.Fprintf(os.Stderr, "INFO: Executing tool '%s' with args %s\n", toolCall.Function.Name, toolCall.Function.Arguments)
+
+		// Execute the tool
+		toolResult, err := tool.Impl(functionArgs)
+		if err != nil {
+			// This can happen if the model missed or supplied invalid tool args
+			return nil, fmt.Errorf("Tool '%s' error: %v", toolCall.Function.Name, err)
+		}
+
+		toolResults = append(toolResults, Message{
+			Role:            "tool",
+			OriginalContent: toolResult,
+			ToolCallID:      sql.NullString{String: toolCall.ID, Valid: true},
+			// name is not required since the introduction of ToolCallID
+			// hypothesis: by setting it, we inform the model of what a
+			// function's purpose was if future requests omit the function
+			// definition
+		})
+	}
+	return toolResults, nil
+}
+
+func ReadDir(path string) string {
+	// TODO: ensure it is not possible to escape to directories above CWD
+	// TODO: implement whitelist - list of directories which model is allowed to work in
+	targetPath := filepath.Join(".", path)
+	files, err := os.ReadDir(targetPath)
+	if err != nil {
+		return resultToJson(FunctionResult{
+			Message: err.Error(),
+		})
+	}
+
+	var dirContents []map[string]interface{}
+	for _, f := range files {
+		info, _ := f.Info()
+
+		contentType := "file"
+		length := info.Size()
+
+		if info.IsDir() {
+			contentType = "dir"
+			subdirfiles, _ := os.ReadDir(filepath.Join(".", path, info.Name()))
+			length = int64(len(subdirfiles))
+		}
+
+		dirContents = append(dirContents, map[string]interface{}{
+			"name":   f.Name(),
+			"type":   contentType,
+			"length": length,
+		})
+	}
+
+	return resultToJson(FunctionResult{Result: dirContents})
+}
diff --git a/pkg/cli/openai.go b/pkg/cli/openai.go
index 0cc1196..d490e44 100644
--- a/pkg/cli/openai.go
+++ b/pkg/cli/openai.go
@@ -2,7 +2,10 @@ package cli
 
 import (
 	"context"
+	"database/sql"
+	"encoding/json"
 	"errors"
+	"fmt"
 	"io"
 	"strings"
 
@@ -12,22 +15,44 @@ import (
 func CreateChatCompletionRequest(model string, messages []Message, maxTokens int) openai.ChatCompletionRequest {
 	chatCompletionMessages := []openai.ChatCompletionMessage{}
 	for _, m := range messages {
-		chatCompletionMessages = append(chatCompletionMessages, openai.ChatCompletionMessage{
+		message := openai.ChatCompletionMessage{
 			Role:    m.Role,
 			Content: m.OriginalContent,
-		})
+		}
+		if m.ToolCallID.Valid {
+			message.ToolCallID = m.ToolCallID.String
+		}
+		if m.ToolCalls.Valid {
+			// unmarshal directly into chatMessage.ToolCalls
+			err := json.Unmarshal([]byte(m.ToolCalls.String), &message.ToolCalls)
+			if err != nil {
+				// TODO: handle, this shouldn't really happen since
+				// we only save the successfully marshal'd data to database
+				fmt.Printf("Error unmarshalling the tool_calls JSON: %v\n", err)
+			}
+		}
+		chatCompletionMessages = append(chatCompletionMessages, message)
+	}
+
+	var tools []openai.Tool
+	for _, t := range AvailableTools {
+		// TODO: support some way to limit which tools are available per-request
+		tools = append(tools, t.Tool)
 	}
 
 	return openai.ChatCompletionRequest{
-		Model:     model,
-		Messages:  chatCompletionMessages,
-		MaxTokens: maxTokens,
-		N:         1, // limit responses to 1 "choice". we use choices[0] to reference it
+		Model:      model,
+		Messages:   chatCompletionMessages,
+		MaxTokens:  maxTokens,
+		N:          1, // limit responses to 1 "choice". we use choices[0] to reference it
+		Tools:      tools,
+		ToolChoice: "auto", // TODO: allow limiting/forcing which function is called?
 	}
 }
 
 // CreateChatCompletion submits a Chat Completion API request and returns the
-// response.
+// response. CreateChatCompletion will recursively call itself in the case of
+// tool calls, until a response is received with the final user-facing output.
func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) { client := openai.NewClient(*config.OpenAI.APIKey) req := CreateChatCompletionRequest(model, messages, maxTokens) @@ -36,7 +61,33 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri return "", err } - return resp.Choices[0].Message.Content, nil + choice := resp.Choices[0] + + if len(choice.Message.ToolCalls) > 0 { + if choice.Message.Content != "" { + return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.") + } + + // Append the assistant's reply with its request for tool calls + toolCallJson, _ := json.Marshal(choice.Message.ToolCalls) + messages = append(messages, Message{ + ConversationID: messages[0].ConversationID, + Role: "assistant", + ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true}, + }) + + toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls) + if err != nil { + return "", err + } + + // Recurse into CreateChatCompletion with the tool call replies added + // to the original messages + return CreateChatCompletion(model, append(messages, toolReplies...), maxTokens) + } + + // Return the user-facing message. + return choice.Message.Content, nil } // CreateChatCompletionStream submits a streaming Chat Completion API request diff --git a/pkg/cli/store.go b/pkg/cli/store.go index 6518a8d..2ed7909 100644 --- a/pkg/cli/store.go +++ b/pkg/cli/store.go @@ -23,8 +23,10 @@ type Message struct { ConversationID uint `gorm:"foreignKey:ConversationID"` Conversation Conversation OriginalContent string - Role string // 'user' or 'assistant' + Role string // one of: 'user', 'assistant', 'tool' CreatedAt time.Time + ToolCallID sql.NullString + ToolCalls sql.NullString // a json-encoded array of tool calls from the model } type Conversation struct {