Compare commits
8 Commits
3859084fd8
...
6f9b79afa1
Author | SHA1 | Date | |
---|---|---|---|
6f9b79afa1 | |||
42f7b7aa29 | |||
9976c59f58 | |||
4b85b005dd | |||
59487d5721 | |||
1fc0af56df | |||
fa27f83630 | |||
b229c42811 |
518
pkg/cli/functions.go
Normal file
518
pkg/cli/functions.go
Normal file
@ -0,0 +1,518 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type FunctionResult struct {
|
||||
Message string `json:"message"`
|
||||
Result any `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
type FunctionParameter struct {
|
||||
Type string `json:"type"` // "string", "integer", "boolean"
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}
|
||||
|
||||
type FunctionParameters struct {
|
||||
Type string `json:"type"` // "object"
|
||||
Properties map[string]FunctionParameter `json:"properties"`
|
||||
Required []string `json:"required,omitempty"` // required function parameter names
|
||||
}
|
||||
|
||||
type AvailableTool struct {
|
||||
openai.Tool
|
||||
// The tool's implementation. Returns a string, as tool call results
|
||||
// are treated as normal messages with string contents.
|
||||
Impl func(arguments map[string]interface{}) (string, error)
|
||||
}
|
||||
|
||||
var AvailableTools = map[string]AvailableTool{
|
||||
"read_dir": {
|
||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||
Name: "read_dir",
|
||||
Description: `Return the contents of the CWD (current working directory).
|
||||
|
||||
Results are returned as JSON in the following format:
|
||||
{
|
||||
"message": "success", // if successful, or a different message indicating failure
|
||||
// result may be an empty array if there are no files in the directory
|
||||
"result": [
|
||||
{"name": "a_file", "type": "file", "size": 123},
|
||||
{"name": "a_directory/", "type": "dir", "size": 11},
|
||||
... // more files or directories
|
||||
]
|
||||
}
|
||||
|
||||
For files, size represents the size (in bytes) of the file.
|
||||
For directories, size represents the number of entries in that directory.`,
|
||||
Parameters: FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]FunctionParameter{
|
||||
"relative_dir": {
|
||||
Type: "string",
|
||||
Description: "If set, read the contents of a directory relative to the current one.",
|
||||
},
|
||||
},
|
||||
},
|
||||
}},
|
||||
Impl: func(args map[string]interface{}) (string, error) {
|
||||
var relativeDir string
|
||||
tmp, ok := args["relative_dir"]
|
||||
if ok {
|
||||
relativeDir, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
return ReadDir(relativeDir), nil
|
||||
},
|
||||
},
|
||||
"read_file": {
|
||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||
Name: "read_file",
|
||||
Description: `Read the contents of a text file relative to the current working directory.
|
||||
|
||||
Each line of the file is prefixed with its line number and a tabs (\t) to make
|
||||
it make it easier to see which lines to change for other modifications.
|
||||
|
||||
Example:
|
||||
{
|
||||
"message": "success", // if successful, or a different message indicating failure
|
||||
"result": "1\tthe contents\n2\tof the file\n"
|
||||
}`,
|
||||
Parameters: FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]FunctionParameter{
|
||||
"path": {
|
||||
Type: "string",
|
||||
Description: "Path to a file within the current working directory to read.",
|
||||
},
|
||||
},
|
||||
Required: []string{"path"},
|
||||
},
|
||||
}},
|
||||
Impl: func(args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Path parameter to read_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
return ReadFile(path), nil
|
||||
},
|
||||
},
|
||||
"write_file": {
|
||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||
Name: "write_file",
|
||||
Description: `Write the provided contents to a file relative to the current working directory.
|
||||
|
||||
Result is returned as JSON in the following format:
|
||||
{
|
||||
"message": "success", // if successful, or a different message indicating failure
|
||||
}`,
|
||||
Parameters: FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]FunctionParameter{
|
||||
"path": {
|
||||
Type: "string",
|
||||
Description: "Path to a file within the current working directory to write to.",
|
||||
},
|
||||
"content": {
|
||||
Type: "string",
|
||||
Description: "The content to write to the file. Overwrites any existing content!",
|
||||
},
|
||||
},
|
||||
Required: []string{"path", "content"},
|
||||
},
|
||||
}},
|
||||
Impl: func(args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Path parameter to write_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
tmp, ok = args["content"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Content parameter to write_file was not included.")
|
||||
}
|
||||
content, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||
}
|
||||
return WriteFile(path, content), nil
|
||||
},
|
||||
},
|
||||
"modify_file": {
|
||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||
Name: "modify_file",
|
||||
Description: `Perform complex line-based modifications to a file.
|
||||
|
||||
Line ranges are inclusive. If 'start_line' is specified but 'end_line' is not,
|
||||
'end_line' gets set to the last line of the file.
|
||||
|
||||
To replace or remove a single line, *set start_line and end_line to the same value*
|
||||
|
||||
Examples:
|
||||
* Insert the lines "hello<new line>world" at line 10, preserving other content:
|
||||
{"path": "myfile", "operation": "insert", "start_line": 10, "content": "hello\nworld"}
|
||||
|
||||
* Remove lines 45 up to and including 54:
|
||||
{"path": "myfile", "operation": "remove", "start_line": 45, "end_line": 54}
|
||||
|
||||
* Replace content from line 10 to 25:
|
||||
{"path": "myfile", "operation": "replace", "start_line": 10, "end_line": 25, "content": "i\nwas\nhere"}
|
||||
|
||||
* Replace contents of entire the file:
|
||||
{"path": "myfile", "operation": "replace", "start_line": 0, "content": "i\nwas\nhere"}`,
|
||||
Parameters: FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]FunctionParameter{
|
||||
"path": {
|
||||
Type: "string",
|
||||
Description: "Path of the file to be modified, relative to the current working directory.",
|
||||
},
|
||||
"operation": {
|
||||
Type: "string",
|
||||
Description: `The the type of modification to make to the file. One of: insert, remove, replace`,
|
||||
},
|
||||
"start_line": {
|
||||
Type: "integer",
|
||||
Description: `(Optional) Where to start making a modification (insert, remove, and replace).`,
|
||||
},
|
||||
"end_line": {
|
||||
Type: "integer",
|
||||
Description: `(Optional) Where to stop making a modification (remove or replace, end of file if omitted).`,
|
||||
},
|
||||
"content": {
|
||||
Type: "string",
|
||||
Description: `(Optional) The content to insert, or replace with.`,
|
||||
},
|
||||
},
|
||||
Required: []string{"path", "operation"},
|
||||
},
|
||||
}},
|
||||
Impl: func(args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path parameter to write_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
tmp, ok = args["operation"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("operation parameter to modify_file was not included.")
|
||||
}
|
||||
operation, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid operation in function arguments: %v", tmp)
|
||||
}
|
||||
var start_line int
|
||||
tmp, ok = args["start_line"]
|
||||
if ok {
|
||||
tmp, ok := tmp.(float64)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid start_line in function arguments: %v", tmp)
|
||||
}
|
||||
start_line = int(tmp)
|
||||
}
|
||||
var end_line int
|
||||
tmp, ok = args["end_line"]
|
||||
if ok {
|
||||
tmp, ok := tmp.(float64)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid end_line in function arguments: %v", tmp)
|
||||
}
|
||||
end_line = int(tmp)
|
||||
}
|
||||
var content string
|
||||
tmp, ok = args["content"]
|
||||
if ok {
|
||||
content, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
|
||||
return ModifyFile(path, operation, content, start_line, end_line), nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func resultToJson(result FunctionResult) string {
|
||||
if result.Message == "" {
|
||||
// When message not supplied, assume success
|
||||
result.Message = "success"
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
fmt.Printf("Could not marshal FunctionResult to JSON: %v\n", err)
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
// ExecuteToolCalls handles the execution of all tool_calls provided, and
|
||||
// returns their results formatted as []Message(s) with role: 'tool' and.
|
||||
func ExecuteToolCalls(toolCalls []openai.ToolCall) ([]Message, error) {
|
||||
var toolResults []Message
|
||||
for _, toolCall := range toolCalls {
|
||||
if toolCall.Type != "function" {
|
||||
// unsupported tool type
|
||||
continue
|
||||
}
|
||||
|
||||
tool, ok := AvailableTools[toolCall.Function.Name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Function.Name)
|
||||
}
|
||||
|
||||
var functionArgs map[string]interface{}
|
||||
err := json.Unmarshal([]byte(toolCall.Function.Arguments), &functionArgs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not unmarshal tool arguments. Malformed JSON? Error: %v", err)
|
||||
}
|
||||
|
||||
// TODO: ability to silence this
|
||||
fmt.Fprintf(os.Stderr, "INFO: Executing tool '%s' with args %s\n", toolCall.Function.Name, toolCall.Function.Arguments)
|
||||
|
||||
// Execute the tool
|
||||
toolResult, err := tool.Impl(functionArgs)
|
||||
if err != nil {
|
||||
// This can happen if the model missed or supplied invalid tool args
|
||||
return nil, fmt.Errorf("Tool '%s' error: %v\n", toolCall.Function.Name, err)
|
||||
}
|
||||
|
||||
toolResults = append(toolResults, Message{
|
||||
Role: "tool",
|
||||
OriginalContent: toolResult,
|
||||
ToolCallID: sql.NullString{String: toolCall.ID, Valid: true},
|
||||
// name is not required since the introduction of ToolCallID
|
||||
// hypothesis: by setting it, we inform the model of what a
|
||||
// function's purpose was if future requests omit the function
|
||||
// definition
|
||||
})
|
||||
}
|
||||
return toolResults, nil
|
||||
}
|
||||
|
||||
// isPathContained attempts to verify whether `path` is the same as or
|
||||
// contained within `directory`. It is overly cautious, returning false even if
|
||||
// `path` IS contained within `directory`, but the two paths use different
|
||||
// casing, and we happen to be on a case-insensitive filesystem.
|
||||
// This is ultimately to attempt to stop an LLM from going outside of where I
|
||||
// tell it to. Additional layers of security should be considered.. run in a
|
||||
// VM/container.
|
||||
func isPathContained(directory string, path string) (bool, error) {
|
||||
// Clean and resolve symlinks for both paths
|
||||
path, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// check if path exists
|
||||
_, err = os.Stat(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return false, fmt.Errorf("Could not stat path: %v", err)
|
||||
}
|
||||
} else {
|
||||
path, err = filepath.EvalSymlinks(path)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
directory, err = filepath.Abs(directory)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
directory, err = filepath.EvalSymlinks(directory)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Case insensitive checks
|
||||
if !strings.EqualFold(path, directory) &&
|
||||
!strings.HasPrefix(strings.ToLower(path), strings.ToLower(directory)+string(os.PathSeparator)) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func isPathWithinCWD(path string) (bool, *FunctionResult) {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return false, &FunctionResult{Message: "Failed to determine current working directory"}
|
||||
}
|
||||
if ok, err := isPathContained(cwd, path); !ok {
|
||||
if err != nil {
|
||||
return false, &FunctionResult{Message: fmt.Sprintf("Could not determine whether path '%s' is within the current working directory: %s", path, err.Error())}
|
||||
}
|
||||
return false, &FunctionResult{Message: fmt.Sprintf("Path '%s' is not within the current working directory", path)}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func ReadDir(path string) string {
|
||||
// TODO(?): implement whitelist - list of directories which model is allowed to work in
|
||||
if path == "" {
|
||||
path = "."
|
||||
}
|
||||
ok, res := isPathWithinCWD(path)
|
||||
if !ok {
|
||||
return resultToJson(*res)
|
||||
}
|
||||
|
||||
files, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{
|
||||
Message: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
var dirContents []map[string]interface{}
|
||||
for _, f := range files {
|
||||
info, _ := f.Info()
|
||||
|
||||
name := f.Name()
|
||||
if strings.HasPrefix(name, ".") {
|
||||
// skip hidden files
|
||||
continue
|
||||
}
|
||||
|
||||
entryType := "file"
|
||||
size := info.Size()
|
||||
|
||||
if info.IsDir() {
|
||||
name += "/"
|
||||
entryType = "dir"
|
||||
subdirfiles, _ := os.ReadDir(filepath.Join(".", path, info.Name()))
|
||||
size = int64(len(subdirfiles))
|
||||
}
|
||||
|
||||
dirContents = append(dirContents, map[string]interface{}{
|
||||
"name": name,
|
||||
"type": entryType,
|
||||
"size": size,
|
||||
})
|
||||
}
|
||||
|
||||
return resultToJson(FunctionResult{Result: dirContents})
|
||||
}
|
||||
|
||||
func ReadFile(path string) string {
|
||||
ok, res := isPathWithinCWD(path)
|
||||
if !ok {
|
||||
return resultToJson(*res)
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
content := strings.Builder{}
|
||||
for i, line := range lines {
|
||||
content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
|
||||
}
|
||||
|
||||
return resultToJson(FunctionResult{
|
||||
Result: content.String(),
|
||||
})
|
||||
}
|
||||
|
||||
func WriteFile(path string, content string) string {
|
||||
ok, res := isPathWithinCWD(path)
|
||||
if !ok {
|
||||
return resultToJson(*res)
|
||||
}
|
||||
err := os.WriteFile(path, []byte(content), 0644)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
|
||||
}
|
||||
return resultToJson(FunctionResult{})
|
||||
}
|
||||
|
||||
func ModifyFile(path string, operation string, content string, startLine int, endLine int) string {
|
||||
ok, res := isPathWithinCWD(path)
|
||||
if !ok {
|
||||
return resultToJson(*res)
|
||||
}
|
||||
|
||||
// Read the existing file's content
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
|
||||
}
|
||||
_, err = os.Create(path)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())})
|
||||
}
|
||||
data = []byte{}
|
||||
}
|
||||
|
||||
if startLine < 0 {
|
||||
return resultToJson(FunctionResult{Message: "start_line cannot be less than 0"})
|
||||
}
|
||||
|
||||
// Split the content by newline to process lines
|
||||
lines := strings.Split(string(data), "\n")
|
||||
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
|
||||
|
||||
switch operation {
|
||||
case "insert":
|
||||
// Insert new lines
|
||||
before := lines[:startLine-1]
|
||||
after := append(contentLines, lines[startLine-1:]...)
|
||||
lines = append(before, after...)
|
||||
case "remove":
|
||||
// Remove lines
|
||||
if endLine == 0 || endLine > len(lines) {
|
||||
endLine = len(lines)
|
||||
}
|
||||
|
||||
lines = append(lines[:startLine-1], lines[endLine:]...)
|
||||
case "replace":
|
||||
// Replace the lines between start_line and end_line
|
||||
if endLine == 0 || endLine > len(lines) {
|
||||
endLine = len(lines)
|
||||
}
|
||||
if startLine == 0 {
|
||||
// model likely trying to replace contents, must start at line 1
|
||||
startLine = 1
|
||||
}
|
||||
before := lines[:startLine-1]
|
||||
after := lines[endLine:]
|
||||
lines = append(before, append(contentLines, after...)...)
|
||||
default:
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Invalid operation: %s", operation)})
|
||||
}
|
||||
|
||||
newContent := strings.Join(lines, "\n")
|
||||
|
||||
// Join the lines and write back to the file
|
||||
err = os.WriteFile(path, []byte(newContent), 0644)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
|
||||
}
|
||||
|
||||
return resultToJson(FunctionResult{Result: newContent})
|
||||
}
|
@ -2,7 +2,10 @@ package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
@ -12,10 +15,29 @@ import (
|
||||
func CreateChatCompletionRequest(model string, messages []Message, maxTokens int) openai.ChatCompletionRequest {
|
||||
chatCompletionMessages := []openai.ChatCompletionMessage{}
|
||||
for _, m := range messages {
|
||||
chatCompletionMessages = append(chatCompletionMessages, openai.ChatCompletionMessage{
|
||||
message := openai.ChatCompletionMessage{
|
||||
Role: m.Role,
|
||||
Content: m.OriginalContent,
|
||||
})
|
||||
}
|
||||
if m.ToolCallID.Valid {
|
||||
message.ToolCallID = m.ToolCallID.String
|
||||
}
|
||||
if m.ToolCalls.Valid {
|
||||
// unmarshal directly into chatMessage.ToolCalls
|
||||
err := json.Unmarshal([]byte(m.ToolCalls.String), &message.ToolCalls)
|
||||
if err != nil {
|
||||
// TODO: handle, this shouldn't really happen since
|
||||
// we only save the successfully marshal'd data to database
|
||||
fmt.Printf("Error unmarshalling the tool_calls JSON: %v\n", err)
|
||||
}
|
||||
}
|
||||
chatCompletionMessages = append(chatCompletionMessages, message)
|
||||
}
|
||||
|
||||
var tools []openai.Tool
|
||||
for _, t := range AvailableTools {
|
||||
// TODO: support some way to limit which tools are available per-request
|
||||
tools = append(tools, t.Tool)
|
||||
}
|
||||
|
||||
return openai.ChatCompletionRequest{
|
||||
@ -23,11 +45,14 @@ func CreateChatCompletionRequest(model string, messages []Message, maxTokens int
|
||||
Messages: chatCompletionMessages,
|
||||
MaxTokens: maxTokens,
|
||||
N: 1, // limit responses to 1 "choice". we use choices[0] to reference it
|
||||
Tools: tools,
|
||||
ToolChoice: "auto", // TODO: allow limiting/forcing which function is called?
|
||||
}
|
||||
}
|
||||
|
||||
// CreateChatCompletion submits a Chat Completion API request and returns the
|
||||
// response.
|
||||
// response. CreateChatCompletion will recursively call itself in the case of
|
||||
// tool calls, until a response is received with the final user-facing output.
|
||||
func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) {
|
||||
client := openai.NewClient(*config.OpenAI.APIKey)
|
||||
req := CreateChatCompletionRequest(model, messages, maxTokens)
|
||||
@ -36,7 +61,32 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri
|
||||
return "", err
|
||||
}
|
||||
|
||||
return resp.Choices[0].Message.Content, nil
|
||||
choice := resp.Choices[0]
|
||||
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
if choice.Message.Content != "" {
|
||||
return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.")
|
||||
}
|
||||
|
||||
// Append the assistant's reply with its request for tool calls
|
||||
toolCallJson, _ := json.Marshal(choice.Message.ToolCalls)
|
||||
messages = append(messages, Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
|
||||
})
|
||||
|
||||
toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Recurse into CreateChatCompletion with the tool call replies added
|
||||
// to the original messages
|
||||
return CreateChatCompletion(model, append(messages, toolReplies...), maxTokens)
|
||||
}
|
||||
|
||||
// Return the user-facing message.
|
||||
return choice.Message.Content, nil
|
||||
}
|
||||
|
||||
// CreateChatCompletionStream submits a streaming Chat Completion API request
|
||||
@ -52,7 +102,10 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
sb := strings.Builder{}
|
||||
content := strings.Builder{}
|
||||
toolCalls := []openai.ToolCall{}
|
||||
|
||||
// Iterate stream segments
|
||||
for {
|
||||
response, e := stream.Recv()
|
||||
if errors.Is(e, io.EOF) {
|
||||
@ -63,9 +116,47 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
|
||||
err = e
|
||||
break
|
||||
}
|
||||
chunk := response.Choices[0].Delta.Content
|
||||
output <- chunk
|
||||
sb.WriteString(chunk)
|
||||
|
||||
delta := response.Choices[0].Delta
|
||||
if len(delta.ToolCalls) > 0 {
|
||||
// Construct streamed tool_call arguments
|
||||
for _, tc := range delta.ToolCalls {
|
||||
if tc.Index == nil {
|
||||
return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
|
||||
}
|
||||
return sb.String(), err
|
||||
if len(toolCalls) <= *tc.Index {
|
||||
toolCalls = append(toolCalls, tc)
|
||||
} else {
|
||||
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
|
||||
}
|
||||
}
|
||||
} else {
|
||||
output <- delta.Content
|
||||
content.WriteString(delta.Content)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
if content.String() != "" {
|
||||
return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.")
|
||||
}
|
||||
|
||||
// Append the assistant's reply with its request for tool calls
|
||||
toolCallJson, _ := json.Marshal(toolCalls)
|
||||
messages = append(messages, Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
|
||||
})
|
||||
|
||||
toolReplies, err := ExecuteToolCalls(toolCalls)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Recurse into CreateChatCompletionStream with the tool call replies
|
||||
// added to the original messages
|
||||
return CreateChatCompletionStream(model, append(messages, toolReplies...), maxTokens, output)
|
||||
}
|
||||
|
||||
return content.String(), err
|
||||
}
|
||||
|
@ -23,8 +23,10 @@ type Message struct {
|
||||
ConversationID uint `gorm:"foreignKey:ConversationID"`
|
||||
Conversation Conversation
|
||||
OriginalContent string
|
||||
Role string // 'user' or 'assistant'
|
||||
Role string // one of: 'user', 'assistant', 'tool'
|
||||
CreatedAt time.Time
|
||||
ToolCallID sql.NullString
|
||||
ToolCalls sql.NullString // a json-encoded array of tool calls from the model
|
||||
}
|
||||
|
||||
type Conversation struct {
|
||||
|
Loading…
Reference in New Issue
Block a user