Compare commits

..

No commits in common. "d32e9421fe541860a57624031a69ab6bf8049654" and "1e63c09907bb1330cda9e1c8dfc1282a1b98a411" have entirely different histories.

6 changed files with 95 additions and 782 deletions

View File

@ -53,10 +53,9 @@ func SystemPrompt() string {
return systemPrompt return systemPrompt
} }
// LLMRequest prompts the LLM with the given messages, writing the response // LLMRequest prompts the LLM with the given Message, writes the (partial)
// to stdout. Returns all reply messages added by the LLM, including any // response to stdout, and returns the (partial) response or any errors.
// function call messages. func LLMRequest(messages []Message) (string, error) {
func LLMRequest(messages []Message) ([]Message, error) {
// receiver receives the reponse from LLM // receiver receives the reponse from LLM
receiver := make(chan string) receiver := make(chan string)
defer close(receiver) defer close(receiver)
@ -64,8 +63,7 @@ func LLMRequest(messages []Message) ([]Message, error) {
// start HandleDelayedContent goroutine to print received data to stdout // start HandleDelayedContent goroutine to print received data to stdout
go HandleDelayedContent(receiver) go HandleDelayedContent(receiver)
var replies []Message response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver, &replies)
if response != "" { if response != "" {
if err != nil { if err != nil {
Warn("Received partial response. Error: %v\n", err) Warn("Received partial response. Error: %v\n", err)
@ -75,23 +73,7 @@ func LLMRequest(messages []Message) ([]Message, error) {
fmt.Println() fmt.Println()
} }
return replies, err return response, err
}
func (c *Conversation) GenerateAndSaveReplies(messages []Message) {
replies, err := LLMRequest(messages)
if err != nil {
Fatal("Error fetching LLM response: %v\n", err)
}
for _, reply := range replies {
reply.ConversationID = c.ID
err = store.SaveMessage(&reply)
if err != nil {
Warn("Could not save reply: %v\n", err)
}
}
} }
// InputFromArgsOrEditor returns either the provided input from the args slice // InputFromArgsOrEditor returns either the provided input from the args slice
@ -334,9 +316,23 @@ var replyCmd = &cobra.Command{
messages = append(messages, userReply) messages = append(messages, userReply)
RenderConversation(messages, true) RenderConversation(messages, true)
(&Message{Role: MessageRoleAssistant}).RenderTTY() assistantReply := Message{
ConversationID: conversation.ID,
Role: "assistant",
}
assistantReply.RenderTTY()
conversation.GenerateAndSaveReplies(messages) response, err := LLMRequest(messages)
if err != nil {
Fatal("Error fetching LLM response: %v\n", err)
}
assistantReply.OriginalContent = response
err = store.SaveMessage(&assistantReply)
if err != nil {
Fatal("Could not save assistant reply: %v\n", err)
}
}, },
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp compMode := cobra.ShellCompDirectiveNoFileComp
@ -383,9 +379,23 @@ var newCmd = &cobra.Command{
} }
RenderConversation(messages, true) RenderConversation(messages, true)
(&Message{Role: MessageRoleAssistant}).RenderTTY() reply := Message{
ConversationID: conversation.ID,
Role: "assistant",
}
reply.RenderTTY()
conversation.GenerateAndSaveReplies(messages) response, err := LLMRequest(messages)
if err != nil {
Fatal("Error fetching LLM response: %v\n", err)
}
reply.OriginalContent = response
err = store.SaveMessage(&reply)
if err != nil {
Fatal("Could not save reply: %v\n", err)
}
err = conversation.GenerateTitle() err = conversation.GenerateTitle()
if err != nil { if err != nil {
@ -451,28 +461,33 @@ var retryCmd = &cobra.Command{
} }
var lastUserMessageIndex int var lastUserMessageIndex int
// walk backwards through conversations to find last user message
for i := len(messages) - 1; i >= 0; i-- { for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == MessageRoleUser { if messages[i].Role == "user" {
lastUserMessageIndex = i lastUserMessageIndex = i
break break
} }
if lastUserMessageIndex == 0 {
// haven't found the the last user message yet, delete this one
err = store.DeleteMessage(&messages[i])
if err != nil {
Warn("Could not delete previous reply: %v\n", err)
}
}
} }
messages = messages[:lastUserMessageIndex+1] messages = messages[:lastUserMessageIndex+1]
RenderConversation(messages, true) RenderConversation(messages, true)
(&Message{Role: MessageRoleAssistant}).RenderTTY() assistantReply := Message{
ConversationID: conversation.ID,
Role: "assistant",
}
assistantReply.RenderTTY()
conversation.GenerateAndSaveReplies(messages) response, err := LLMRequest(messages)
if err != nil {
Fatal("Error fetching LLM response: %v\n", err)
}
assistantReply.OriginalContent = response
err = store.SaveMessage(&assistantReply)
if err != nil {
Fatal("Could not save assistant reply: %v\n", err)
}
}, },
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp compMode := cobra.ShellCompDirectiveNoFileComp
@ -507,9 +522,23 @@ var continueCmd = &cobra.Command{
} }
RenderConversation(messages, true) RenderConversation(messages, true)
(&Message{Role: MessageRoleAssistant}).RenderTTY() assistantReply := Message{
ConversationID: conversation.ID,
Role: "assistant",
}
assistantReply.RenderTTY()
conversation.GenerateAndSaveReplies(messages) response, err := LLMRequest(messages)
if err != nil {
Fatal("Error fetching LLM response: %v\n", err)
}
assistantReply.OriginalContent = response
err = store.SaveMessage(&assistantReply)
if err != nil {
Fatal("Could not save assistant reply: %v\n", err)
}
}, },
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp compMode := cobra.ShellCompDirectiveNoFileComp

View File

@ -16,7 +16,6 @@ type Config struct {
APIKey *string `yaml:"apiKey" default:"your_key_here"` APIKey *string `yaml:"apiKey" default:"your_key_here"`
DefaultModel *string `yaml:"defaultModel" default:"gpt-4"` DefaultModel *string `yaml:"defaultModel" default:"gpt-4"`
DefaultMaxLength *int `yaml:"defaultMaxLength" default:"256"` DefaultMaxLength *int `yaml:"defaultMaxLength" default:"256"`
EnabledTools []string `yaml:"enabledTools"`
} `yaml:"openai"` } `yaml:"openai"`
Chroma *struct { Chroma *struct {
Style *string `yaml:"style" default:"onedark"` Style *string `yaml:"style" default:"onedark"`

View File

@ -5,43 +5,35 @@ import (
"strings" "strings"
) )
type MessageRole string
const (
MessageRoleUser MessageRole = "user"
MessageRoleAssistant MessageRole = "assistant"
MessageRoleSystem MessageRole = "system"
)
// FriendlyRole returns a human friendly signifier for the message's role. // FriendlyRole returns a human friendly signifier for the message's role.
func (m *Message) FriendlyRole() string { func (m *Message) FriendlyRole() string {
var friendlyRole string var friendlyRole string
switch m.Role { switch m.Role {
case MessageRoleUser: case "user":
friendlyRole = "You" friendlyRole = "You"
case MessageRoleSystem: case "system":
friendlyRole = "System" friendlyRole = "System"
case MessageRoleAssistant: case "assistant":
friendlyRole = "Assistant" friendlyRole = "Assistant"
default: default:
friendlyRole = string(m.Role) friendlyRole = m.Role
} }
return friendlyRole return friendlyRole
} }
func (c *Conversation) GenerateTitle() error { func (c *Conversation) GenerateTitle() error {
const header = "Generate a consise 4-5 word title for the conversation below." const header = "Generate a consise 4-5 word title for the conversation below."
prompt := fmt.Sprintf("%s\n\n---\n\n%s", header, c.FormatForExternalPrompting(false)) prompt := fmt.Sprintf("%s\n\n---\n\n%s", header, c.FormatForExternalPrompting())
messages := []Message{ messages := []Message{
{ {
Role: MessageRoleUser, Role: "user",
OriginalContent: prompt, OriginalContent: prompt,
}, },
} }
model := "gpt-3.5-turbo" // use cheap model to generate title model := "gpt-3.5-turbo" // use cheap model to generate title
response, err := CreateChatCompletion(model, messages, 25, nil) response, err := CreateChatCompletion(model, messages, 25)
if err != nil { if err != nil {
return err return err
} }
@ -50,16 +42,13 @@ func (c *Conversation) GenerateTitle() error {
return nil return nil
} }
func (c *Conversation) FormatForExternalPrompting(system bool) string { func (c *Conversation) FormatForExternalPrompting() string {
sb := strings.Builder{} sb := strings.Builder{}
messages, err := store.Messages(c) messages, err := store.Messages(c)
if err != nil { if err != nil {
Fatal("Could not retrieve messages for conversation %v", c) Fatal("Could not retrieve messages for conversation %v", c)
} }
for _, message := range messages { for _, message := range messages {
if message.Role == MessageRoleSystem && !system {
continue
}
sb.WriteString(fmt.Sprintf("<%s>\n", message.FriendlyRole())) sb.WriteString(fmt.Sprintf("<%s>\n", message.FriendlyRole()))
sb.WriteString(fmt.Sprintf("\"\"\"\n%s\n\"\"\"\n\n", message.OriginalContent)) sb.WriteString(fmt.Sprintf("\"\"\"\n%s\n\"\"\"\n\n", message.OriginalContent))
} }

View File

@ -1,582 +0,0 @@
package cli
import (
"database/sql"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
openai "github.com/sashabaranov/go-openai"
)
type FunctionResult struct {
Message string `json:"message"`
Result any `json:"result,omitempty"`
}
type FunctionParameter struct {
Type string `json:"type"` // "string", "integer", "boolean"
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}
type FunctionParameters struct {
Type string `json:"type"` // "object"
Properties map[string]FunctionParameter `json:"properties"`
Required []string `json:"required,omitempty"` // required function parameter names
}
type AvailableTool struct {
openai.Tool
// The tool's implementation. Returns a string, as tool call results
// are treated as normal messages with string contents.
Impl func(arguments map[string]interface{}) (string, error)
}
const (
READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory).
Results are returned as JSON in the following format:
{
"message": "success", // if successful, or a different message indicating failure
// result may be an empty array if there are no files in the directory
"result": [
{"name": "a_file", "type": "file", "size": 123},
{"name": "a_directory/", "type": "dir", "size": 11},
... // more files or directories
]
}
For files, size represents the size (in bytes) of the file.
For directories, size represents the number of entries in that directory.`
READ_FILE_DESCRIPTION = `Read the contents of a text file relative to the current working directory.
Each line of the file is prefixed with its line number and a tabs (\t) to make
it make it easier to see which lines to change for other modifications.
Example result:
{
"message": "success", // if successful, or a different message indicating failure
"result": "1\tthe contents\n2\tof the file\n"
}`
WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory.
Note: only use this tool when you've been explicitly asked to create or write to a file.
When using this function, you do not need to share the content you intend to write with the user first.
Example result:
{
"message": "success", // if successful, or a different message indicating failure
}`
FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path.
Make sure your inserts match the flow and indentation of surrounding content.`
FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path.
Useful for re-writing snippets/blocks of code or entire functions.
Be cautious with your edits. When replacing, ensure the replacement content matches the flow and indentation of surrounding content.`
)
var AvailableTools = map[string]AvailableTool{
"read_dir": {
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
Name: "read_dir",
Description: READ_DIR_DESCRIPTION,
Parameters: FunctionParameters{
Type: "object",
Properties: map[string]FunctionParameter{
"relative_dir": {
Type: "string",
Description: "If set, read the contents of a directory relative to the current one.",
},
},
},
}},
Impl: func(args map[string]interface{}) (string, error) {
var relativeDir string
tmp, ok := args["relative_dir"]
if ok {
relativeDir, ok = tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp)
}
}
return ReadDir(relativeDir), nil
},
},
"read_file": {
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
Name: "read_file",
Description: READ_FILE_DESCRIPTION,
Parameters: FunctionParameters{
Type: "object",
Properties: map[string]FunctionParameter{
"path": {
Type: "string",
Description: "Path to a file within the current working directory to read.",
},
},
Required: []string{"path"},
},
}},
Impl: func(args map[string]interface{}) (string, error) {
tmp, ok := args["path"]
if !ok {
return "", fmt.Errorf("Path parameter to read_file was not included.")
}
path, ok := tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
}
return ReadFile(path), nil
},
},
"write_file": {
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
Name: "write_file",
Description: WRITE_FILE_DESCRIPTION,
Parameters: FunctionParameters{
Type: "object",
Properties: map[string]FunctionParameter{
"path": {
Type: "string",
Description: "Path to a file within the current working directory to write to.",
},
"content": {
Type: "string",
Description: "The content to write to the file. Overwrites any existing content!",
},
},
Required: []string{"path", "content"},
},
}},
Impl: func(args map[string]interface{}) (string, error) {
tmp, ok := args["path"]
if !ok {
return "", fmt.Errorf("Path parameter to write_file was not included.")
}
path, ok := tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
}
tmp, ok = args["content"]
if !ok {
return "", fmt.Errorf("Content parameter to write_file was not included.")
}
content, ok := tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
}
return WriteFile(path, content), nil
},
},
"file_insert_lines": {
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
Name: "file_insert_lines",
Description: FILE_INSERT_LINES_DESCRIPTION,
Parameters: FunctionParameters{
Type: "object",
Properties: map[string]FunctionParameter{
"path": {
Type: "string",
Description: "Path of the file to be modified, relative to the current working directory.",
},
"position": {
Type: "integer",
Description: `Which line to insert content *before*.`,
},
"content": {
Type: "string",
Description: `The content to insert.`,
},
},
Required: []string{"path", "position", "content"},
},
}},
Impl: func(args map[string]interface{}) (string, error) {
tmp, ok := args["path"]
if !ok {
return "", fmt.Errorf("path parameter to write_file was not included.")
}
path, ok := tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
}
var position int
tmp, ok = args["position"]
if ok {
tmp, ok := tmp.(float64)
if !ok {
return "", fmt.Errorf("Invalid position in function arguments: %v", tmp)
}
position = int(tmp)
}
var content string
tmp, ok = args["content"]
if ok {
content, ok = tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
}
}
return FileInsertLines(path, position, content), nil
},
},
"file_replace_lines": {
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
Name: "file_replace_lines",
Description: FILE_REPLACE_LINES_DESCRIPTION,
Parameters: FunctionParameters{
Type: "object",
Properties: map[string]FunctionParameter{
"path": {
Type: "string",
Description: "Path of the file to be modified, relative to the current working directory.",
},
"start_line": {
Type: "integer",
Description: `Line number which specifies the start of the replacement range (inclusive).`,
},
"end_line": {
Type: "integer",
Description: `Line number which specifies the end of the replacement range (inclusive). If unset, range extends to end of file.`,
},
"content": {
Type: "string",
Description: `Content to replace specified range. Omit to remove the specified range.`,
},
},
Required: []string{"path", "start_line"},
},
}},
Impl: func(args map[string]interface{}) (string, error) {
tmp, ok := args["path"]
if !ok {
return "", fmt.Errorf("path parameter to write_file was not included.")
}
path, ok := tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
}
var start_line int
tmp, ok = args["start_line"]
if ok {
tmp, ok := tmp.(float64)
if !ok {
return "", fmt.Errorf("Invalid start_line in function arguments: %v", tmp)
}
start_line = int(tmp)
}
var end_line int
tmp, ok = args["end_line"]
if ok {
tmp, ok := tmp.(float64)
if !ok {
return "", fmt.Errorf("Invalid end_line in function arguments: %v", tmp)
}
end_line = int(tmp)
}
var content string
tmp, ok = args["content"]
if ok {
content, ok = tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
}
}
return FileReplaceLines(path, start_line, end_line, content), nil
},
},
}
func resultToJson(result FunctionResult) string {
if result.Message == "" {
// When message not supplied, assume success
result.Message = "success"
}
jsonBytes, err := json.Marshal(result)
if err != nil {
fmt.Printf("Could not marshal FunctionResult to JSON: %v\n", err)
}
return string(jsonBytes)
}
// ExecuteToolCalls handles the execution of all tool_calls provided, and
// returns their results formatted as []Message(s) with role: 'tool' and.
func ExecuteToolCalls(toolCalls []openai.ToolCall) ([]Message, error) {
var toolResults []Message
for _, toolCall := range toolCalls {
if toolCall.Type != "function" {
// unsupported tool type
continue
}
tool, ok := AvailableTools[toolCall.Function.Name]
if !ok {
return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Function.Name)
}
var functionArgs map[string]interface{}
err := json.Unmarshal([]byte(toolCall.Function.Arguments), &functionArgs)
if err != nil {
return nil, fmt.Errorf("Could not unmarshal tool arguments. Malformed JSON? Error: %v", err)
}
// TODO: ability to silence this
fmt.Fprintf(os.Stderr, "INFO: Executing tool '%s' with args %s\n", toolCall.Function.Name, toolCall.Function.Arguments)
// Execute the tool
toolResult, err := tool.Impl(functionArgs)
if err != nil {
// This can happen if the model missed or supplied invalid tool args
return nil, fmt.Errorf("Tool '%s' error: %v\n", toolCall.Function.Name, err)
}
toolResults = append(toolResults, Message{
Role: "tool",
OriginalContent: toolResult,
ToolCallID: sql.NullString{String: toolCall.ID, Valid: true},
// name is not required since the introduction of ToolCallID
// hypothesis: by setting it, we inform the model of what a
// function's purpose was if future requests omit the function
// definition
})
}
return toolResults, nil
}
// isPathContained attempts to verify whether `path` is the same as or
// contained within `directory`. It is overly cautious, returning false even if
// `path` IS contained within `directory`, but the two paths use different
// casing, and we happen to be on a case-insensitive filesystem.
// This is ultimately to attempt to stop an LLM from going outside of where I
// tell it to. Additional layers of security should be considered.. run in a
// VM/container.
func isPathContained(directory string, path string) (bool, error) {
// Clean and resolve symlinks for both paths
path, err := filepath.Abs(path)
if err != nil {
return false, err
}
// check if path exists
_, err = os.Stat(path)
if err != nil {
if !os.IsNotExist(err) {
return false, fmt.Errorf("Could not stat path: %v", err)
}
} else {
path, err = filepath.EvalSymlinks(path)
if err != nil {
return false, err
}
}
directory, err = filepath.Abs(directory)
if err != nil {
return false, err
}
directory, err = filepath.EvalSymlinks(directory)
if err != nil {
return false, err
}
// Case insensitive checks
if !strings.EqualFold(path, directory) &&
!strings.HasPrefix(strings.ToLower(path), strings.ToLower(directory)+string(os.PathSeparator)) {
return false, nil
}
return true, nil
}
func isPathWithinCWD(path string) (bool, *FunctionResult) {
cwd, err := os.Getwd()
if err != nil {
return false, &FunctionResult{Message: "Failed to determine current working directory"}
}
if ok, err := isPathContained(cwd, path); !ok {
if err != nil {
return false, &FunctionResult{Message: fmt.Sprintf("Could not determine whether path '%s' is within the current working directory: %s", path, err.Error())}
}
return false, &FunctionResult{Message: fmt.Sprintf("Path '%s' is not within the current working directory", path)}
}
return true, nil
}
func ReadDir(path string) string {
// TODO(?): implement whitelist - list of directories which model is allowed to work in
if path == "" {
path = "."
}
ok, res := isPathWithinCWD(path)
if !ok {
return resultToJson(*res)
}
files, err := os.ReadDir(path)
if err != nil {
return resultToJson(FunctionResult{
Message: err.Error(),
})
}
var dirContents []map[string]interface{}
for _, f := range files {
info, _ := f.Info()
name := f.Name()
if strings.HasPrefix(name, ".") {
// skip hidden files
continue
}
entryType := "file"
size := info.Size()
if info.IsDir() {
name += "/"
entryType = "dir"
subdirfiles, _ := os.ReadDir(filepath.Join(".", path, info.Name()))
size = int64(len(subdirfiles))
}
dirContents = append(dirContents, map[string]interface{}{
"name": name,
"type": entryType,
"size": size,
})
}
return resultToJson(FunctionResult{Result: dirContents})
}
func ReadFile(path string) string {
ok, res := isPathWithinCWD(path)
if !ok {
return resultToJson(*res)
}
data, err := os.ReadFile(path)
if err != nil {
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
}
lines := strings.Split(string(data), "\n")
content := strings.Builder{}
for i, line := range lines {
content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
}
return resultToJson(FunctionResult{
Result: content.String(),
})
}
func WriteFile(path string, content string) string {
ok, res := isPathWithinCWD(path)
if !ok {
return resultToJson(*res)
}
err := os.WriteFile(path, []byte(content), 0644)
if err != nil {
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
}
return resultToJson(FunctionResult{})
}
func FileInsertLines(path string, position int, content string) string {
ok, res := isPathWithinCWD(path)
if !ok {
return resultToJson(*res)
}
// Read the existing file's content
data, err := os.ReadFile(path)
if err != nil {
if !os.IsNotExist(err) {
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
}
_, err = os.Create(path)
if err != nil {
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())})
}
data = []byte{}
}
if position < 1 {
return resultToJson(FunctionResult{Message: "start_line cannot be less than 1"})
}
lines := strings.Split(string(data), "\n")
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
before := lines[:position-1]
after := lines[position-1:]
lines = append(before, append(contentLines, after...)...)
newContent := strings.Join(lines, "\n")
// Join the lines and write back to the file
err = os.WriteFile(path, []byte(newContent), 0644)
if err != nil {
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
}
return resultToJson(FunctionResult{Result: newContent})
}
func FileReplaceLines(path string, startLine int, endLine int, content string) string {
ok, res := isPathWithinCWD(path)
if !ok {
return resultToJson(*res)
}
// Read the existing file's content
data, err := os.ReadFile(path)
if err != nil {
if !os.IsNotExist(err) {
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
}
_, err = os.Create(path)
if err != nil {
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())})
}
data = []byte{}
}
if startLine < 1 {
return resultToJson(FunctionResult{Message: "start_line cannot be less than 1"})
}
lines := strings.Split(string(data), "\n")
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
if endLine == 0 || endLine > len(lines) {
endLine = len(lines)
}
before := lines[:startLine-1]
after := lines[endLine:]
lines = append(before, append(contentLines, after...)...)
newContent := strings.Join(lines, "\n")
// Join the lines and write back to the file
err = os.WriteFile(path, []byte(newContent), 0644)
if err != nil {
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
}
return resultToJson(FunctionResult{Result: newContent})
}

View File

@ -2,10 +2,7 @@ package cli
import ( import (
"context" "context"
"database/sql"
"encoding/json"
"errors" "errors"
"fmt"
"io" "io"
"strings" "strings"
@ -15,52 +12,23 @@ import (
func CreateChatCompletionRequest(model string, messages []Message, maxTokens int) openai.ChatCompletionRequest { func CreateChatCompletionRequest(model string, messages []Message, maxTokens int) openai.ChatCompletionRequest {
chatCompletionMessages := []openai.ChatCompletionMessage{} chatCompletionMessages := []openai.ChatCompletionMessage{}
for _, m := range messages { for _, m := range messages {
message := openai.ChatCompletionMessage{ chatCompletionMessages = append(chatCompletionMessages, openai.ChatCompletionMessage{
Role: string(m.Role), Role: m.Role,
Content: m.OriginalContent, Content: m.OriginalContent,
} })
if m.ToolCallID.Valid {
message.ToolCallID = m.ToolCallID.String
}
if m.ToolCalls.Valid {
// unmarshal directly into chatMessage.ToolCalls
err := json.Unmarshal([]byte(m.ToolCalls.String), &message.ToolCalls)
if err != nil {
// TODO: handle, this shouldn't really happen since
// we only save the successfully marshal'd data to database
fmt.Printf("Error unmarshalling the tool_calls JSON: %v\n", err)
}
}
chatCompletionMessages = append(chatCompletionMessages, message)
} }
request := openai.ChatCompletionRequest{ return openai.ChatCompletionRequest{
Model: model, Model: model,
Messages: chatCompletionMessages, Messages: chatCompletionMessages,
MaxTokens: maxTokens, MaxTokens: maxTokens,
N: 1, // limit responses to 1 "choice". we use choices[0] to reference it N: 1, // limit responses to 1 "choice". we use choices[0] to reference it
} }
var tools []openai.Tool
for _, t := range config.OpenAI.EnabledTools {
tool, ok := AvailableTools[t]
if ok {
tools = append(tools, tool.Tool)
}
}
if len(tools) > 0 {
request.Tools = tools
request.ToolChoice = "auto"
}
return request
} }
// CreateChatCompletion submits a Chat Completion API request and returns the // CreateChatCompletion submits a Chat Completion API request and returns the
// response. CreateChatCompletion will recursively call itself in the case of // response.
// tool calls, until a response is received with the final user-facing output. func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) {
func CreateChatCompletion(model string, messages []Message, maxTokens int, replies *[]Message) (string, error) {
client := openai.NewClient(*config.OpenAI.APIKey) client := openai.NewClient(*config.OpenAI.APIKey)
req := CreateChatCompletionRequest(model, messages, maxTokens) req := CreateChatCompletionRequest(model, messages, maxTokens)
resp, err := client.CreateChatCompletion(context.Background(), req) resp, err := client.CreateChatCompletion(context.Background(), req)
@ -68,46 +36,13 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int, repli
return "", err return "", err
} }
choice := resp.Choices[0] return resp.Choices[0].Message.Content, nil
if len(choice.Message.ToolCalls) > 0 {
// Append the assistant's reply with its request for tool calls
toolCallJson, _ := json.Marshal(choice.Message.ToolCalls)
assistantReply := Message{
Role: MessageRoleAssistant,
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
}
toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls)
if err != nil {
return "", err
}
if replies != nil {
*replies = append(append(*replies, assistantReply), toolReplies...)
}
messages = append(append(messages, assistantReply), toolReplies...)
// Recurse into CreateChatCompletion with the tool call replies added
// to the original messages
return CreateChatCompletion(model, append(messages, toolReplies...), maxTokens, replies)
}
if replies != nil {
*replies = append(*replies, Message{
Role: MessageRoleAssistant,
OriginalContent: choice.Message.Content,
})
}
// Return the user-facing message.
return choice.Message.Content, nil
} }
// CreateChatCompletionStream submits a streaming Chat Completion API request // CreateChatCompletionStream submits a streaming Chat Completion API request
// and both returns and streams the response to the provided output channel. // and both returns and streams the response to the provided output channel.
// May return a partial response if an error occurs mid-stream. // May return a partial response if an error occurs mid-stream.
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string, replies *[]Message) (string, error) { func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string) (string, error) {
client := openai.NewClient(*config.OpenAI.APIKey) client := openai.NewClient(*config.OpenAI.APIKey)
req := CreateChatCompletionRequest(model, messages, maxTokens) req := CreateChatCompletionRequest(model, messages, maxTokens)
@ -117,10 +52,7 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
} }
defer stream.Close() defer stream.Close()
content := strings.Builder{} sb := strings.Builder{}
toolCalls := []openai.ToolCall{}
// Iterate stream segments
for { for {
response, e := stream.Recv() response, e := stream.Recv()
if errors.Is(e, io.EOF) { if errors.Is(e, io.EOF) {
@ -131,57 +63,9 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
err = e err = e
break break
} }
chunk := response.Choices[0].Delta.Content
delta := response.Choices[0].Delta output <- chunk
if len(delta.ToolCalls) > 0 { sb.WriteString(chunk)
// Construct streamed tool_call arguments
for _, tc := range delta.ToolCalls {
if tc.Index == nil {
return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
} }
if len(toolCalls) <= *tc.Index { return sb.String(), err
toolCalls = append(toolCalls, tc)
} else {
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
}
}
} else {
output <- delta.Content
content.WriteString(delta.Content)
}
}
if len(toolCalls) > 0 {
// Append the assistant's reply with its request for tool calls
toolCallJson, _ := json.Marshal(toolCalls)
assistantReply := Message{
Role: MessageRoleAssistant,
OriginalContent: content.String(),
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
}
toolReplies, err := ExecuteToolCalls(toolCalls)
if err != nil {
return "", err
}
if replies != nil {
*replies = append(append(*replies, assistantReply), toolReplies...)
}
// Recurse into CreateChatCompletionStream with the tool call replies
// added to the original messages
messages = append(append(messages, assistantReply), toolReplies...)
return CreateChatCompletionStream(model, messages, maxTokens, output, replies)
}
if replies != nil {
*replies = append(*replies, Message{
Role: MessageRoleAssistant,
OriginalContent: content.String(),
})
}
return content.String(), err
} }

View File

@ -23,10 +23,8 @@ type Message struct {
ConversationID uint `gorm:"foreignKey:ConversationID"` ConversationID uint `gorm:"foreignKey:ConversationID"`
Conversation Conversation Conversation Conversation
OriginalContent string OriginalContent string
Role MessageRole // one of: 'system', 'user', 'assistant', 'tool' Role string // 'user' or 'assistant'
CreatedAt time.Time CreatedAt time.Time
ToolCallID sql.NullString
ToolCalls sql.NullString // a json-encoded array of tool calls from the model
} }
type Conversation struct { type Conversation struct {
@ -97,10 +95,6 @@ func (s *Store) SaveMessage(message *Message) error {
return s.db.Create(message).Error return s.db.Create(message).Error
} }
func (s *Store) DeleteMessage(message *Message) error {
return s.db.Delete(&message).Error
}
func (s *Store) Conversations() ([]Conversation, error) { func (s *Store) Conversations() ([]Conversation, error) {
var conversations []Conversation var conversations []Conversation
err := s.db.Find(&conversations).Error err := s.db.Find(&conversations).Error