Add initial support for tools

So far only supported on the non-streaming endpoint.

Added the `read_dir` tool for reading contents from paths relative to the
current working directory.
This commit is contained in:
Matt Low 2023-11-26 00:55:18 +00:00
parent 1e63c09907
commit bf1f23b1d6
3 changed files with 230 additions and 9 deletions

168
pkg/cli/functions.go Normal file
View File

@ -0,0 +1,168 @@
package cli
import (
	"database/sql"
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"strings"

	openai "github.com/sashabaranov/go-openai"
)
// FunctionResult is the payload returned to the model after a tool call.
// Message is "success" on success or an error description on failure;
// Result carries the tool's output and is omitted from the JSON when empty.
type FunctionResult struct {
	Message string `json:"message"`
	Result  any    `json:"result,omitempty"`
}
// FunctionParameter describes a single named parameter in a tool's
// JSON-schema style parameter declaration.
type FunctionParameter struct {
	Type        string   `json:"type"` // "string", "integer", "boolean"
	Description string   `json:"description"`
	Enum        []string `json:"enum,omitempty"` // allowed values, if the parameter is an enumeration
}
// FunctionParameters is the top-level JSON-schema style object describing
// all parameters a tool accepts.
type FunctionParameters struct {
	Type       string                       `json:"type"` // "object"
	Properties map[string]FunctionParameter `json:"properties"`
	Required   []string                     `json:"required,omitempty"` // required function parameter names
}
// AvailableTool pairs an openai.Tool definition (sent to the API) with its
// local implementation.
type AvailableTool struct {
	openai.Tool
	// The tool's implementation. Returns a string, as tool call results
	// are treated as normal messages with string contents.
	Impl func(arguments map[string]interface{}) (string, error)
}
// AvailableTools maps tool names to their definitions and implementations.
// Every tool listed here is offered to the model on each request; the
// description string is sent verbatim to the model as part of the function
// definition.
var AvailableTools = map[string]AvailableTool{
	"read_dir": {
		Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
			Name: "read_dir",
			Description: `Return the contents of the CWD (current working directory).
Results are returned as JSON in the following format:
{
"message": "success", // "success" if successful, or a different message indicating failure
"result": [
{"name": "a_file", "type": "file", "length": 123},
{"name": "a_directory", "type": "dir", "length": 5},
... // more files or directories
]
}
For type: file, length represents the size (in bytes) of the file.
For type: dir, length represents the number of entries in that directory.`,
			Parameters: FunctionParameters{
				Type: "object",
				Properties: map[string]FunctionParameter{
					"relative_dir": {
						Type:        "string",
						Description: "If set, read the contents of a directory relative to the current one.",
					},
				},
			},
		}},
		Impl: func(args map[string]interface{}) (string, error) {
			// relative_dir is optional; when absent we list the CWD itself.
			var relativeDir string
			if tmp, ok := args["relative_dir"]; ok {
				relativeDir, ok = tmp.(string)
				if !ok {
					// Model supplied a non-string value for relative_dir.
					return "", fmt.Errorf("invalid relative_dir in function arguments: %v", tmp)
				}
			}
			return ReadDir(relativeDir), nil
		},
	},
}
// resultToJson serializes a FunctionResult into the JSON string handed back
// to the model. An empty Message is treated as success. If marshalling fails
// (e.g. Result holds an unmarshalable value), the error is logged to stderr
// and a minimal well-formed failure payload is returned, so callers never
// receive an empty (invalid JSON) string.
func resultToJson(result FunctionResult) string {
	if result.Message == "" {
		// When message not supplied, assume success
		result.Message = "success"
	}
	jsonBytes, err := json.Marshal(result)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Could not marshal FunctionResult to JSON: %v\n", err)
		// Fall back to a valid JSON payload describing the failure.
		return `{"message": "could not marshal result to JSON"}`
	}
	return string(jsonBytes)
}
// ExecuteToolCalls handles the execution of all tool_calls provided, and
// returns their results formatted as []Message(s) with role: 'tool'.
// It fails fast: the first unknown tool, malformed argument payload, or
// tool error aborts the whole batch with a nil slice and an error.
func ExecuteToolCalls(toolCalls []openai.ToolCall) ([]Message, error) {
	var toolResults []Message
	for _, toolCall := range toolCalls {
		if toolCall.Type != "function" {
			// unsupported tool type
			continue
		}
		tool, ok := AvailableTools[toolCall.Function.Name]
		if !ok {
			return nil, fmt.Errorf("requested tool '%s' does not exist. Hallucination?", toolCall.Function.Name)
		}
		// Arguments arrive as a JSON-encoded string from the model.
		var functionArgs map[string]interface{}
		if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &functionArgs); err != nil {
			return nil, fmt.Errorf("could not unmarshal tool arguments. Malformed JSON? Error: %w", err)
		}
		// TODO: ability to silence this
		fmt.Fprintf(os.Stderr, "INFO: Executing tool '%s' with args %s\n", toolCall.Function.Name, toolCall.Function.Arguments)
		// Execute the tool
		toolResult, err := tool.Impl(functionArgs)
		if err != nil {
			// This can happen if the model missed or supplied invalid tool args
			return nil, fmt.Errorf("tool '%s' error: %w", toolCall.Function.Name, err)
		}
		toolResults = append(toolResults, Message{
			Role:            "tool",
			OriginalContent: toolResult,
			ToolCallID:      sql.NullString{String: toolCall.ID, Valid: true},
			// name is not required since the introduction of ToolCallID
			// hypothesis: by setting it, we inform the model of what a
			// function's purpose was if future requests omit the function
			// definition
		})
	}
	return toolResults, nil
}
// ReadDir lists the contents of path, interpreted relative to the CWD, and
// returns a JSON-encoded FunctionResult string. Each entry reports "name",
// "type" ("file" or "dir"), and "length": file size in bytes for files, or
// number of entries for directories. Failures (including attempts to escape
// the CWD) are reported in the JSON "message" field rather than as an error,
// so the model receives the failure as a normal tool reply.
func ReadDir(path string) string {
	// TODO: implement whitelist - list of directories which model is allowed to work in
	targetPath := filepath.Join(".", path)
	// filepath.Join cleans its result, so any remaining ".." prefix means the
	// (model-supplied) path resolves above the CWD - refuse to read it.
	if targetPath == ".." || strings.HasPrefix(targetPath, ".."+string(filepath.Separator)) {
		return resultToJson(FunctionResult{
			Message: "cannot read outside of the current working directory",
		})
	}
	files, err := os.ReadDir(targetPath)
	if err != nil {
		return resultToJson(FunctionResult{
			Message: err.Error(),
		})
	}
	var dirContents []map[string]interface{}
	for _, f := range files {
		info, err := f.Info()
		if err != nil {
			// Entry may have disappeared between ReadDir and Info; skip it.
			continue
		}
		contentType := "file"
		length := info.Size()
		if info.IsDir() {
			contentType = "dir"
			// Best effort: an unreadable subdirectory reports length 0
			// (len of a nil slice).
			subdirFiles, _ := os.ReadDir(filepath.Join(targetPath, f.Name()))
			length = int64(len(subdirFiles))
		}
		dirContents = append(dirContents, map[string]interface{}{
			"name":   f.Name(),
			"type":   contentType,
			"length": length,
		})
	}
	return resultToJson(FunctionResult{Result: dirContents})
}

View File

@ -2,7 +2,10 @@ package cli
import ( import (
"context" "context"
"database/sql"
"encoding/json"
"errors" "errors"
"fmt"
"io" "io"
"strings" "strings"
@ -12,10 +15,29 @@ import (
func CreateChatCompletionRequest(model string, messages []Message, maxTokens int) openai.ChatCompletionRequest { func CreateChatCompletionRequest(model string, messages []Message, maxTokens int) openai.ChatCompletionRequest {
chatCompletionMessages := []openai.ChatCompletionMessage{} chatCompletionMessages := []openai.ChatCompletionMessage{}
for _, m := range messages { for _, m := range messages {
chatCompletionMessages = append(chatCompletionMessages, openai.ChatCompletionMessage{ message := openai.ChatCompletionMessage{
Role: m.Role, Role: m.Role,
Content: m.OriginalContent, Content: m.OriginalContent,
}) }
if m.ToolCallID.Valid {
message.ToolCallID = m.ToolCallID.String
}
if m.ToolCalls.Valid {
// unmarshal directly into chatMessage.ToolCalls
err := json.Unmarshal([]byte(m.ToolCalls.String), &message.ToolCalls)
if err != nil {
// TODO: handle, this shouldn't really happen since
// we only save the successfully marshal'd data to database
fmt.Printf("Error unmarshalling the tool_calls JSON: %v\n", err)
}
}
chatCompletionMessages = append(chatCompletionMessages, message)
}
var tools []openai.Tool
for _, t := range AvailableTools {
// TODO: support some way to limit which tools are available per-request
tools = append(tools, t.Tool)
} }
return openai.ChatCompletionRequest{ return openai.ChatCompletionRequest{
@ -23,11 +45,14 @@ func CreateChatCompletionRequest(model string, messages []Message, maxTokens int
Messages: chatCompletionMessages, Messages: chatCompletionMessages,
MaxTokens: maxTokens, MaxTokens: maxTokens,
N: 1, // limit responses to 1 "choice". we use choices[0] to reference it N: 1, // limit responses to 1 "choice". we use choices[0] to reference it
Tools: tools,
ToolChoice: "auto", // TODO: allow limiting/forcing which function is called?
} }
} }
// CreateChatCompletion submits a Chat Completion API request and returns the // CreateChatCompletion submits a Chat Completion API request and returns the
// response. // response. CreateChatCompletion will recursively call itself in the case of
// tool calls, until a response is received with the final user-facing output.
func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) { func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) {
client := openai.NewClient(*config.OpenAI.APIKey) client := openai.NewClient(*config.OpenAI.APIKey)
req := CreateChatCompletionRequest(model, messages, maxTokens) req := CreateChatCompletionRequest(model, messages, maxTokens)
@ -36,7 +61,33 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri
return "", err return "", err
} }
return resp.Choices[0].Message.Content, nil choice := resp.Choices[0]
if len(choice.Message.ToolCalls) > 0 {
if choice.Message.Content != "" {
return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.")
}
// Append the assistant's reply with its request for tool calls
toolCallJson, _ := json.Marshal(choice.Message.ToolCalls)
messages = append(messages, Message{
ConversationID: messages[0].ConversationID,
Role: "assistant",
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
})
toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls)
if err != nil {
return "", err
}
// Recurse into CreateChatCompletion with the tool call replies added
// to the original messages
return CreateChatCompletion(model, append(messages, toolReplies...), maxTokens)
}
// Return the user-facing message.
return choice.Message.Content, nil
} }
// CreateChatCompletionStream submits a streaming Chat Completion API request // CreateChatCompletionStream submits a streaming Chat Completion API request

View File

@ -23,8 +23,10 @@ type Message struct {
ConversationID uint `gorm:"foreignKey:ConversationID"` ConversationID uint `gorm:"foreignKey:ConversationID"`
Conversation Conversation Conversation Conversation
OriginalContent string OriginalContent string
Role string // 'user' or 'assistant' Role string // one of: 'user', 'assistant', 'tool'
CreatedAt time.Time CreatedAt time.Time
ToolCallID sql.NullString
ToolCalls sql.NullString // a json-encoded array of tool calls from the model
} }
type Conversation struct { type Conversation struct {