Compare commits
6 Commits
1e63c09907
...
d32e9421fe
Author | SHA1 | Date | |
---|---|---|---|
d32e9421fe | |||
e29dbaf2a3 | |||
c64bc370f4 | |||
4f37ed046b | |||
ed6ee9bea9 | |||
e850c340b7 |
111
pkg/cli/cmd.go
111
pkg/cli/cmd.go
@ -53,9 +53,10 @@ func SystemPrompt() string {
|
|||||||
return systemPrompt
|
return systemPrompt
|
||||||
}
|
}
|
||||||
|
|
||||||
// LLMRequest prompts the LLM with the given Message, writes the (partial)
|
// LLMRequest prompts the LLM with the given messages, writing the response
|
||||||
// response to stdout, and returns the (partial) response or any errors.
|
// to stdout. Returns all reply messages added by the LLM, including any
|
||||||
func LLMRequest(messages []Message) (string, error) {
|
// function call messages.
|
||||||
|
func LLMRequest(messages []Message) ([]Message, error) {
|
||||||
// receiver receives the reponse from LLM
|
// receiver receives the reponse from LLM
|
||||||
receiver := make(chan string)
|
receiver := make(chan string)
|
||||||
defer close(receiver)
|
defer close(receiver)
|
||||||
@ -63,7 +64,8 @@ func LLMRequest(messages []Message) (string, error) {
|
|||||||
// start HandleDelayedContent goroutine to print received data to stdout
|
// start HandleDelayedContent goroutine to print received data to stdout
|
||||||
go HandleDelayedContent(receiver)
|
go HandleDelayedContent(receiver)
|
||||||
|
|
||||||
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
var replies []Message
|
||||||
|
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver, &replies)
|
||||||
if response != "" {
|
if response != "" {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Warn("Received partial response. Error: %v\n", err)
|
Warn("Received partial response. Error: %v\n", err)
|
||||||
@ -73,7 +75,23 @@ func LLMRequest(messages []Message) (string, error) {
|
|||||||
fmt.Println()
|
fmt.Println()
|
||||||
}
|
}
|
||||||
|
|
||||||
return response, err
|
return replies, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conversation) GenerateAndSaveReplies(messages []Message) {
|
||||||
|
replies, err := LLMRequest(messages)
|
||||||
|
if err != nil {
|
||||||
|
Fatal("Error fetching LLM response: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, reply := range replies {
|
||||||
|
reply.ConversationID = c.ID
|
||||||
|
|
||||||
|
err = store.SaveMessage(&reply)
|
||||||
|
if err != nil {
|
||||||
|
Warn("Could not save reply: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// InputFromArgsOrEditor returns either the provided input from the args slice
|
// InputFromArgsOrEditor returns either the provided input from the args slice
|
||||||
@ -316,23 +334,9 @@ var replyCmd = &cobra.Command{
|
|||||||
messages = append(messages, userReply)
|
messages = append(messages, userReply)
|
||||||
|
|
||||||
RenderConversation(messages, true)
|
RenderConversation(messages, true)
|
||||||
assistantReply := Message{
|
(&Message{Role: MessageRoleAssistant}).RenderTTY()
|
||||||
ConversationID: conversation.ID,
|
|
||||||
Role: "assistant",
|
|
||||||
}
|
|
||||||
assistantReply.RenderTTY()
|
|
||||||
|
|
||||||
response, err := LLMRequest(messages)
|
conversation.GenerateAndSaveReplies(messages)
|
||||||
if err != nil {
|
|
||||||
Fatal("Error fetching LLM response: %v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assistantReply.OriginalContent = response
|
|
||||||
|
|
||||||
err = store.SaveMessage(&assistantReply)
|
|
||||||
if err != nil {
|
|
||||||
Fatal("Could not save assistant reply: %v\n", err)
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||||
@ -379,23 +383,9 @@ var newCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
|
|
||||||
RenderConversation(messages, true)
|
RenderConversation(messages, true)
|
||||||
reply := Message{
|
(&Message{Role: MessageRoleAssistant}).RenderTTY()
|
||||||
ConversationID: conversation.ID,
|
|
||||||
Role: "assistant",
|
|
||||||
}
|
|
||||||
reply.RenderTTY()
|
|
||||||
|
|
||||||
response, err := LLMRequest(messages)
|
conversation.GenerateAndSaveReplies(messages)
|
||||||
if err != nil {
|
|
||||||
Fatal("Error fetching LLM response: %v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
reply.OriginalContent = response
|
|
||||||
|
|
||||||
err = store.SaveMessage(&reply)
|
|
||||||
if err != nil {
|
|
||||||
Fatal("Could not save reply: %v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = conversation.GenerateTitle()
|
err = conversation.GenerateTitle()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -461,33 +451,28 @@ var retryCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
|
|
||||||
var lastUserMessageIndex int
|
var lastUserMessageIndex int
|
||||||
|
// walk backwards through conversations to find last user message
|
||||||
for i := len(messages) - 1; i >= 0; i-- {
|
for i := len(messages) - 1; i >= 0; i-- {
|
||||||
if messages[i].Role == "user" {
|
if messages[i].Role == MessageRoleUser {
|
||||||
lastUserMessageIndex = i
|
lastUserMessageIndex = i
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if lastUserMessageIndex == 0 {
|
||||||
|
// haven't found the the last user message yet, delete this one
|
||||||
|
err = store.DeleteMessage(&messages[i])
|
||||||
|
if err != nil {
|
||||||
|
Warn("Could not delete previous reply: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
messages = messages[:lastUserMessageIndex+1]
|
messages = messages[:lastUserMessageIndex+1]
|
||||||
|
|
||||||
RenderConversation(messages, true)
|
RenderConversation(messages, true)
|
||||||
assistantReply := Message{
|
(&Message{Role: MessageRoleAssistant}).RenderTTY()
|
||||||
ConversationID: conversation.ID,
|
|
||||||
Role: "assistant",
|
|
||||||
}
|
|
||||||
assistantReply.RenderTTY()
|
|
||||||
|
|
||||||
response, err := LLMRequest(messages)
|
conversation.GenerateAndSaveReplies(messages)
|
||||||
if err != nil {
|
|
||||||
Fatal("Error fetching LLM response: %v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assistantReply.OriginalContent = response
|
|
||||||
|
|
||||||
err = store.SaveMessage(&assistantReply)
|
|
||||||
if err != nil {
|
|
||||||
Fatal("Could not save assistant reply: %v\n", err)
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||||
@ -522,23 +507,9 @@ var continueCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
|
|
||||||
RenderConversation(messages, true)
|
RenderConversation(messages, true)
|
||||||
assistantReply := Message{
|
(&Message{Role: MessageRoleAssistant}).RenderTTY()
|
||||||
ConversationID: conversation.ID,
|
|
||||||
Role: "assistant",
|
|
||||||
}
|
|
||||||
assistantReply.RenderTTY()
|
|
||||||
|
|
||||||
response, err := LLMRequest(messages)
|
conversation.GenerateAndSaveReplies(messages)
|
||||||
if err != nil {
|
|
||||||
Fatal("Error fetching LLM response: %v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assistantReply.OriginalContent = response
|
|
||||||
|
|
||||||
err = store.SaveMessage(&assistantReply)
|
|
||||||
if err != nil {
|
|
||||||
Fatal("Could not save assistant reply: %v\n", err)
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||||
|
@ -13,9 +13,10 @@ type Config struct {
|
|||||||
SystemPrompt *string `yaml:"systemPrompt" default:"You are a helpful assistant."`
|
SystemPrompt *string `yaml:"systemPrompt" default:"You are a helpful assistant."`
|
||||||
} `yaml:"modelDefaults"`
|
} `yaml:"modelDefaults"`
|
||||||
OpenAI *struct {
|
OpenAI *struct {
|
||||||
APIKey *string `yaml:"apiKey" default:"your_key_here"`
|
APIKey *string `yaml:"apiKey" default:"your_key_here"`
|
||||||
DefaultModel *string `yaml:"defaultModel" default:"gpt-4"`
|
DefaultModel *string `yaml:"defaultModel" default:"gpt-4"`
|
||||||
DefaultMaxLength *int `yaml:"defaultMaxLength" default:"256"`
|
DefaultMaxLength *int `yaml:"defaultMaxLength" default:"256"`
|
||||||
|
EnabledTools []string `yaml:"enabledTools"`
|
||||||
} `yaml:"openai"`
|
} `yaml:"openai"`
|
||||||
Chroma *struct {
|
Chroma *struct {
|
||||||
Style *string `yaml:"style" default:"onedark"`
|
Style *string `yaml:"style" default:"onedark"`
|
||||||
|
@ -5,35 +5,43 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type MessageRole string
|
||||||
|
|
||||||
|
const (
|
||||||
|
MessageRoleUser MessageRole = "user"
|
||||||
|
MessageRoleAssistant MessageRole = "assistant"
|
||||||
|
MessageRoleSystem MessageRole = "system"
|
||||||
|
)
|
||||||
|
|
||||||
// FriendlyRole returns a human friendly signifier for the message's role.
|
// FriendlyRole returns a human friendly signifier for the message's role.
|
||||||
func (m *Message) FriendlyRole() string {
|
func (m *Message) FriendlyRole() string {
|
||||||
var friendlyRole string
|
var friendlyRole string
|
||||||
switch m.Role {
|
switch m.Role {
|
||||||
case "user":
|
case MessageRoleUser:
|
||||||
friendlyRole = "You"
|
friendlyRole = "You"
|
||||||
case "system":
|
case MessageRoleSystem:
|
||||||
friendlyRole = "System"
|
friendlyRole = "System"
|
||||||
case "assistant":
|
case MessageRoleAssistant:
|
||||||
friendlyRole = "Assistant"
|
friendlyRole = "Assistant"
|
||||||
default:
|
default:
|
||||||
friendlyRole = m.Role
|
friendlyRole = string(m.Role)
|
||||||
}
|
}
|
||||||
return friendlyRole
|
return friendlyRole
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conversation) GenerateTitle() error {
|
func (c *Conversation) GenerateTitle() error {
|
||||||
const header = "Generate a consise 4-5 word title for the conversation below."
|
const header = "Generate a consise 4-5 word title for the conversation below."
|
||||||
prompt := fmt.Sprintf("%s\n\n---\n\n%s", header, c.FormatForExternalPrompting())
|
prompt := fmt.Sprintf("%s\n\n---\n\n%s", header, c.FormatForExternalPrompting(false))
|
||||||
|
|
||||||
messages := []Message{
|
messages := []Message{
|
||||||
{
|
{
|
||||||
Role: "user",
|
Role: MessageRoleUser,
|
||||||
OriginalContent: prompt,
|
OriginalContent: prompt,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
model := "gpt-3.5-turbo" // use cheap model to generate title
|
model := "gpt-3.5-turbo" // use cheap model to generate title
|
||||||
response, err := CreateChatCompletion(model, messages, 25)
|
response, err := CreateChatCompletion(model, messages, 25, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -42,13 +50,16 @@ func (c *Conversation) GenerateTitle() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conversation) FormatForExternalPrompting() string {
|
func (c *Conversation) FormatForExternalPrompting(system bool) string {
|
||||||
sb := strings.Builder{}
|
sb := strings.Builder{}
|
||||||
messages, err := store.Messages(c)
|
messages, err := store.Messages(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Fatal("Could not retrieve messages for conversation %v", c)
|
Fatal("Could not retrieve messages for conversation %v", c)
|
||||||
}
|
}
|
||||||
for _, message := range messages {
|
for _, message := range messages {
|
||||||
|
if message.Role == MessageRoleSystem && !system {
|
||||||
|
continue
|
||||||
|
}
|
||||||
sb.WriteString(fmt.Sprintf("<%s>\n", message.FriendlyRole()))
|
sb.WriteString(fmt.Sprintf("<%s>\n", message.FriendlyRole()))
|
||||||
sb.WriteString(fmt.Sprintf("\"\"\"\n%s\n\"\"\"\n\n", message.OriginalContent))
|
sb.WriteString(fmt.Sprintf("\"\"\"\n%s\n\"\"\"\n\n", message.OriginalContent))
|
||||||
}
|
}
|
||||||
|
582
pkg/cli/functions.go
Normal file
582
pkg/cli/functions.go
Normal file
@ -0,0 +1,582 @@
|
|||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
openai "github.com/sashabaranov/go-openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
type FunctionResult struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Result any `json:"result,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FunctionParameter struct {
|
||||||
|
Type string `json:"type"` // "string", "integer", "boolean"
|
||||||
|
Description string `json:"description"`
|
||||||
|
Enum []string `json:"enum,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FunctionParameters struct {
|
||||||
|
Type string `json:"type"` // "object"
|
||||||
|
Properties map[string]FunctionParameter `json:"properties"`
|
||||||
|
Required []string `json:"required,omitempty"` // required function parameter names
|
||||||
|
}
|
||||||
|
|
||||||
|
type AvailableTool struct {
|
||||||
|
openai.Tool
|
||||||
|
// The tool's implementation. Returns a string, as tool call results
|
||||||
|
// are treated as normal messages with string contents.
|
||||||
|
Impl func(arguments map[string]interface{}) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory).
|
||||||
|
|
||||||
|
Results are returned as JSON in the following format:
|
||||||
|
{
|
||||||
|
"message": "success", // if successful, or a different message indicating failure
|
||||||
|
// result may be an empty array if there are no files in the directory
|
||||||
|
"result": [
|
||||||
|
{"name": "a_file", "type": "file", "size": 123},
|
||||||
|
{"name": "a_directory/", "type": "dir", "size": 11},
|
||||||
|
... // more files or directories
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
For files, size represents the size (in bytes) of the file.
|
||||||
|
For directories, size represents the number of entries in that directory.`
|
||||||
|
|
||||||
|
READ_FILE_DESCRIPTION = `Read the contents of a text file relative to the current working directory.
|
||||||
|
|
||||||
|
Each line of the file is prefixed with its line number and a tabs (\t) to make
|
||||||
|
it make it easier to see which lines to change for other modifications.
|
||||||
|
|
||||||
|
Example result:
|
||||||
|
{
|
||||||
|
"message": "success", // if successful, or a different message indicating failure
|
||||||
|
"result": "1\tthe contents\n2\tof the file\n"
|
||||||
|
}`
|
||||||
|
|
||||||
|
WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory.
|
||||||
|
|
||||||
|
Note: only use this tool when you've been explicitly asked to create or write to a file.
|
||||||
|
|
||||||
|
When using this function, you do not need to share the content you intend to write with the user first.
|
||||||
|
|
||||||
|
Example result:
|
||||||
|
{
|
||||||
|
"message": "success", // if successful, or a different message indicating failure
|
||||||
|
}`
|
||||||
|
|
||||||
|
FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path.
|
||||||
|
|
||||||
|
Make sure your inserts match the flow and indentation of surrounding content.`
|
||||||
|
|
||||||
|
FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path.
|
||||||
|
|
||||||
|
Useful for re-writing snippets/blocks of code or entire functions.
|
||||||
|
|
||||||
|
Be cautious with your edits. When replacing, ensure the replacement content matches the flow and indentation of surrounding content.`
|
||||||
|
)
|
||||||
|
|
||||||
|
var AvailableTools = map[string]AvailableTool{
|
||||||
|
"read_dir": {
|
||||||
|
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||||
|
Name: "read_dir",
|
||||||
|
Description: READ_DIR_DESCRIPTION,
|
||||||
|
Parameters: FunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]FunctionParameter{
|
||||||
|
"relative_dir": {
|
||||||
|
Type: "string",
|
||||||
|
Description: "If set, read the contents of a directory relative to the current one.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
Impl: func(args map[string]interface{}) (string, error) {
|
||||||
|
var relativeDir string
|
||||||
|
tmp, ok := args["relative_dir"]
|
||||||
|
if ok {
|
||||||
|
relativeDir, ok = tmp.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ReadDir(relativeDir), nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"read_file": {
|
||||||
|
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||||
|
Name: "read_file",
|
||||||
|
Description: READ_FILE_DESCRIPTION,
|
||||||
|
Parameters: FunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]FunctionParameter{
|
||||||
|
"path": {
|
||||||
|
Type: "string",
|
||||||
|
Description: "Path to a file within the current working directory to read.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Required: []string{"path"},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
Impl: func(args map[string]interface{}) (string, error) {
|
||||||
|
tmp, ok := args["path"]
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Path parameter to read_file was not included.")
|
||||||
|
}
|
||||||
|
path, ok := tmp.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
return ReadFile(path), nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"write_file": {
|
||||||
|
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||||
|
Name: "write_file",
|
||||||
|
Description: WRITE_FILE_DESCRIPTION,
|
||||||
|
Parameters: FunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]FunctionParameter{
|
||||||
|
"path": {
|
||||||
|
Type: "string",
|
||||||
|
Description: "Path to a file within the current working directory to write to.",
|
||||||
|
},
|
||||||
|
"content": {
|
||||||
|
Type: "string",
|
||||||
|
Description: "The content to write to the file. Overwrites any existing content!",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Required: []string{"path", "content"},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
Impl: func(args map[string]interface{}) (string, error) {
|
||||||
|
tmp, ok := args["path"]
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Path parameter to write_file was not included.")
|
||||||
|
}
|
||||||
|
path, ok := tmp.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
tmp, ok = args["content"]
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Content parameter to write_file was not included.")
|
||||||
|
}
|
||||||
|
content, ok := tmp.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
return WriteFile(path, content), nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"file_insert_lines": {
|
||||||
|
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||||
|
Name: "file_insert_lines",
|
||||||
|
Description: FILE_INSERT_LINES_DESCRIPTION,
|
||||||
|
Parameters: FunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]FunctionParameter{
|
||||||
|
"path": {
|
||||||
|
Type: "string",
|
||||||
|
Description: "Path of the file to be modified, relative to the current working directory.",
|
||||||
|
},
|
||||||
|
"position": {
|
||||||
|
Type: "integer",
|
||||||
|
Description: `Which line to insert content *before*.`,
|
||||||
|
},
|
||||||
|
"content": {
|
||||||
|
Type: "string",
|
||||||
|
Description: `The content to insert.`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Required: []string{"path", "position", "content"},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
Impl: func(args map[string]interface{}) (string, error) {
|
||||||
|
tmp, ok := args["path"]
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("path parameter to write_file was not included.")
|
||||||
|
}
|
||||||
|
path, ok := tmp.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
var position int
|
||||||
|
tmp, ok = args["position"]
|
||||||
|
if ok {
|
||||||
|
tmp, ok := tmp.(float64)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid position in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
position = int(tmp)
|
||||||
|
}
|
||||||
|
var content string
|
||||||
|
tmp, ok = args["content"]
|
||||||
|
if ok {
|
||||||
|
content, ok = tmp.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return FileInsertLines(path, position, content), nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"file_replace_lines": {
|
||||||
|
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||||
|
Name: "file_replace_lines",
|
||||||
|
Description: FILE_REPLACE_LINES_DESCRIPTION,
|
||||||
|
Parameters: FunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]FunctionParameter{
|
||||||
|
"path": {
|
||||||
|
Type: "string",
|
||||||
|
Description: "Path of the file to be modified, relative to the current working directory.",
|
||||||
|
},
|
||||||
|
"start_line": {
|
||||||
|
Type: "integer",
|
||||||
|
Description: `Line number which specifies the start of the replacement range (inclusive).`,
|
||||||
|
},
|
||||||
|
"end_line": {
|
||||||
|
Type: "integer",
|
||||||
|
Description: `Line number which specifies the end of the replacement range (inclusive). If unset, range extends to end of file.`,
|
||||||
|
},
|
||||||
|
"content": {
|
||||||
|
Type: "string",
|
||||||
|
Description: `Content to replace specified range. Omit to remove the specified range.`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Required: []string{"path", "start_line"},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
Impl: func(args map[string]interface{}) (string, error) {
|
||||||
|
tmp, ok := args["path"]
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("path parameter to write_file was not included.")
|
||||||
|
}
|
||||||
|
path, ok := tmp.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
var start_line int
|
||||||
|
tmp, ok = args["start_line"]
|
||||||
|
if ok {
|
||||||
|
tmp, ok := tmp.(float64)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid start_line in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
start_line = int(tmp)
|
||||||
|
}
|
||||||
|
var end_line int
|
||||||
|
tmp, ok = args["end_line"]
|
||||||
|
if ok {
|
||||||
|
tmp, ok := tmp.(float64)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid end_line in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
end_line = int(tmp)
|
||||||
|
}
|
||||||
|
var content string
|
||||||
|
tmp, ok = args["content"]
|
||||||
|
if ok {
|
||||||
|
content, ok = tmp.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return FileReplaceLines(path, start_line, end_line, content), nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func resultToJson(result FunctionResult) string {
|
||||||
|
if result.Message == "" {
|
||||||
|
// When message not supplied, assume success
|
||||||
|
result.Message = "success"
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonBytes, err := json.Marshal(result)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Could not marshal FunctionResult to JSON: %v\n", err)
|
||||||
|
}
|
||||||
|
return string(jsonBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteToolCalls handles the execution of all tool_calls provided, and
|
||||||
|
// returns their results formatted as []Message(s) with role: 'tool' and.
|
||||||
|
func ExecuteToolCalls(toolCalls []openai.ToolCall) ([]Message, error) {
|
||||||
|
var toolResults []Message
|
||||||
|
for _, toolCall := range toolCalls {
|
||||||
|
if toolCall.Type != "function" {
|
||||||
|
// unsupported tool type
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
tool, ok := AvailableTools[toolCall.Function.Name]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Function.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
var functionArgs map[string]interface{}
|
||||||
|
err := json.Unmarshal([]byte(toolCall.Function.Arguments), &functionArgs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not unmarshal tool arguments. Malformed JSON? Error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: ability to silence this
|
||||||
|
fmt.Fprintf(os.Stderr, "INFO: Executing tool '%s' with args %s\n", toolCall.Function.Name, toolCall.Function.Arguments)
|
||||||
|
|
||||||
|
// Execute the tool
|
||||||
|
toolResult, err := tool.Impl(functionArgs)
|
||||||
|
if err != nil {
|
||||||
|
// This can happen if the model missed or supplied invalid tool args
|
||||||
|
return nil, fmt.Errorf("Tool '%s' error: %v\n", toolCall.Function.Name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
toolResults = append(toolResults, Message{
|
||||||
|
Role: "tool",
|
||||||
|
OriginalContent: toolResult,
|
||||||
|
ToolCallID: sql.NullString{String: toolCall.ID, Valid: true},
|
||||||
|
// name is not required since the introduction of ToolCallID
|
||||||
|
// hypothesis: by setting it, we inform the model of what a
|
||||||
|
// function's purpose was if future requests omit the function
|
||||||
|
// definition
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return toolResults, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPathContained attempts to verify whether `path` is the same as or
|
||||||
|
// contained within `directory`. It is overly cautious, returning false even if
|
||||||
|
// `path` IS contained within `directory`, but the two paths use different
|
||||||
|
// casing, and we happen to be on a case-insensitive filesystem.
|
||||||
|
// This is ultimately to attempt to stop an LLM from going outside of where I
|
||||||
|
// tell it to. Additional layers of security should be considered.. run in a
|
||||||
|
// VM/container.
|
||||||
|
func isPathContained(directory string, path string) (bool, error) {
|
||||||
|
// Clean and resolve symlinks for both paths
|
||||||
|
path, err := filepath.Abs(path)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if path exists
|
||||||
|
_, err = os.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
if !os.IsNotExist(err) {
|
||||||
|
return false, fmt.Errorf("Could not stat path: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
path, err = filepath.EvalSymlinks(path)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
directory, err = filepath.Abs(directory)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
directory, err = filepath.EvalSymlinks(directory)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case insensitive checks
|
||||||
|
if !strings.EqualFold(path, directory) &&
|
||||||
|
!strings.HasPrefix(strings.ToLower(path), strings.ToLower(directory)+string(os.PathSeparator)) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isPathWithinCWD(path string) (bool, *FunctionResult) {
|
||||||
|
cwd, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
return false, &FunctionResult{Message: "Failed to determine current working directory"}
|
||||||
|
}
|
||||||
|
if ok, err := isPathContained(cwd, path); !ok {
|
||||||
|
if err != nil {
|
||||||
|
return false, &FunctionResult{Message: fmt.Sprintf("Could not determine whether path '%s' is within the current working directory: %s", path, err.Error())}
|
||||||
|
}
|
||||||
|
return false, &FunctionResult{Message: fmt.Sprintf("Path '%s' is not within the current working directory", path)}
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReadDir(path string) string {
|
||||||
|
// TODO(?): implement whitelist - list of directories which model is allowed to work in
|
||||||
|
if path == "" {
|
||||||
|
path = "."
|
||||||
|
}
|
||||||
|
ok, res := isPathWithinCWD(path)
|
||||||
|
if !ok {
|
||||||
|
return resultToJson(*res)
|
||||||
|
}
|
||||||
|
|
||||||
|
files, err := os.ReadDir(path)
|
||||||
|
if err != nil {
|
||||||
|
return resultToJson(FunctionResult{
|
||||||
|
Message: err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
var dirContents []map[string]interface{}
|
||||||
|
for _, f := range files {
|
||||||
|
info, _ := f.Info()
|
||||||
|
|
||||||
|
name := f.Name()
|
||||||
|
if strings.HasPrefix(name, ".") {
|
||||||
|
// skip hidden files
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
entryType := "file"
|
||||||
|
size := info.Size()
|
||||||
|
|
||||||
|
if info.IsDir() {
|
||||||
|
name += "/"
|
||||||
|
entryType = "dir"
|
||||||
|
subdirfiles, _ := os.ReadDir(filepath.Join(".", path, info.Name()))
|
||||||
|
size = int64(len(subdirfiles))
|
||||||
|
}
|
||||||
|
|
||||||
|
dirContents = append(dirContents, map[string]interface{}{
|
||||||
|
"name": name,
|
||||||
|
"type": entryType,
|
||||||
|
"size": size,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return resultToJson(FunctionResult{Result: dirContents})
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReadFile(path string) string {
|
||||||
|
ok, res := isPathWithinCWD(path)
|
||||||
|
if !ok {
|
||||||
|
return resultToJson(*res)
|
||||||
|
}
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := strings.Split(string(data), "\n")
|
||||||
|
content := strings.Builder{}
|
||||||
|
for i, line := range lines {
|
||||||
|
content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
|
||||||
|
}
|
||||||
|
|
||||||
|
return resultToJson(FunctionResult{
|
||||||
|
Result: content.String(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WriteFile(path string, content string) string {
|
||||||
|
ok, res := isPathWithinCWD(path)
|
||||||
|
if !ok {
|
||||||
|
return resultToJson(*res)
|
||||||
|
}
|
||||||
|
err := os.WriteFile(path, []byte(content), 0644)
|
||||||
|
if err != nil {
|
||||||
|
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
|
||||||
|
}
|
||||||
|
return resultToJson(FunctionResult{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func FileInsertLines(path string, position int, content string) string {
|
||||||
|
ok, res := isPathWithinCWD(path)
|
||||||
|
if !ok {
|
||||||
|
return resultToJson(*res)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the existing file's content
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
if !os.IsNotExist(err) {
|
||||||
|
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
|
||||||
|
}
|
||||||
|
_, err = os.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())})
|
||||||
|
}
|
||||||
|
data = []byte{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if position < 1 {
|
||||||
|
return resultToJson(FunctionResult{Message: "start_line cannot be less than 1"})
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := strings.Split(string(data), "\n")
|
||||||
|
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
|
||||||
|
|
||||||
|
before := lines[:position-1]
|
||||||
|
after := lines[position-1:]
|
||||||
|
lines = append(before, append(contentLines, after...)...)
|
||||||
|
|
||||||
|
newContent := strings.Join(lines, "\n")
|
||||||
|
|
||||||
|
// Join the lines and write back to the file
|
||||||
|
err = os.WriteFile(path, []byte(newContent), 0644)
|
||||||
|
if err != nil {
|
||||||
|
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
|
||||||
|
}
|
||||||
|
|
||||||
|
return resultToJson(FunctionResult{Result: newContent})
|
||||||
|
}
|
||||||
|
|
||||||
|
func FileReplaceLines(path string, startLine int, endLine int, content string) string {
|
||||||
|
ok, res := isPathWithinCWD(path)
|
||||||
|
if !ok {
|
||||||
|
return resultToJson(*res)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the existing file's content
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
if !os.IsNotExist(err) {
|
||||||
|
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
|
||||||
|
}
|
||||||
|
_, err = os.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())})
|
||||||
|
}
|
||||||
|
data = []byte{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if startLine < 1 {
|
||||||
|
return resultToJson(FunctionResult{Message: "start_line cannot be less than 1"})
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := strings.Split(string(data), "\n")
|
||||||
|
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
|
||||||
|
|
||||||
|
if endLine == 0 || endLine > len(lines) {
|
||||||
|
endLine = len(lines)
|
||||||
|
}
|
||||||
|
|
||||||
|
before := lines[:startLine-1]
|
||||||
|
after := lines[endLine:]
|
||||||
|
|
||||||
|
lines = append(before, append(contentLines, after...)...)
|
||||||
|
newContent := strings.Join(lines, "\n")
|
||||||
|
|
||||||
|
// Join the lines and write back to the file
|
||||||
|
err = os.WriteFile(path, []byte(newContent), 0644)
|
||||||
|
if err != nil {
|
||||||
|
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
|
||||||
|
}
|
||||||
|
|
||||||
|
return resultToJson(FunctionResult{Result: newContent})
|
||||||
|
|
||||||
|
}
|
@ -2,7 +2,10 @@ package cli
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -12,23 +15,52 @@ import (
|
|||||||
func CreateChatCompletionRequest(model string, messages []Message, maxTokens int) openai.ChatCompletionRequest {
|
func CreateChatCompletionRequest(model string, messages []Message, maxTokens int) openai.ChatCompletionRequest {
|
||||||
chatCompletionMessages := []openai.ChatCompletionMessage{}
|
chatCompletionMessages := []openai.ChatCompletionMessage{}
|
||||||
for _, m := range messages {
|
for _, m := range messages {
|
||||||
chatCompletionMessages = append(chatCompletionMessages, openai.ChatCompletionMessage{
|
message := openai.ChatCompletionMessage{
|
||||||
Role: m.Role,
|
Role: string(m.Role),
|
||||||
Content: m.OriginalContent,
|
Content: m.OriginalContent,
|
||||||
})
|
}
|
||||||
|
if m.ToolCallID.Valid {
|
||||||
|
message.ToolCallID = m.ToolCallID.String
|
||||||
|
}
|
||||||
|
if m.ToolCalls.Valid {
|
||||||
|
// unmarshal directly into chatMessage.ToolCalls
|
||||||
|
err := json.Unmarshal([]byte(m.ToolCalls.String), &message.ToolCalls)
|
||||||
|
if err != nil {
|
||||||
|
// TODO: handle, this shouldn't really happen since
|
||||||
|
// we only save the successfully marshal'd data to database
|
||||||
|
fmt.Printf("Error unmarshalling the tool_calls JSON: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
chatCompletionMessages = append(chatCompletionMessages, message)
|
||||||
}
|
}
|
||||||
|
|
||||||
return openai.ChatCompletionRequest{
|
request := openai.ChatCompletionRequest{
|
||||||
Model: model,
|
Model: model,
|
||||||
Messages: chatCompletionMessages,
|
Messages: chatCompletionMessages,
|
||||||
MaxTokens: maxTokens,
|
MaxTokens: maxTokens,
|
||||||
N: 1, // limit responses to 1 "choice". we use choices[0] to reference it
|
N: 1, // limit responses to 1 "choice". we use choices[0] to reference it
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var tools []openai.Tool
|
||||||
|
for _, t := range config.OpenAI.EnabledTools {
|
||||||
|
tool, ok := AvailableTools[t]
|
||||||
|
if ok {
|
||||||
|
tools = append(tools, tool.Tool)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tools) > 0 {
|
||||||
|
request.Tools = tools
|
||||||
|
request.ToolChoice = "auto"
|
||||||
|
}
|
||||||
|
|
||||||
|
return request
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateChatCompletion submits a Chat Completion API request and returns the
|
// CreateChatCompletion submits a Chat Completion API request and returns the
|
||||||
// response.
|
// response. CreateChatCompletion will recursively call itself in the case of
|
||||||
func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) {
|
// tool calls, until a response is received with the final user-facing output.
|
||||||
|
func CreateChatCompletion(model string, messages []Message, maxTokens int, replies *[]Message) (string, error) {
|
||||||
client := openai.NewClient(*config.OpenAI.APIKey)
|
client := openai.NewClient(*config.OpenAI.APIKey)
|
||||||
req := CreateChatCompletionRequest(model, messages, maxTokens)
|
req := CreateChatCompletionRequest(model, messages, maxTokens)
|
||||||
resp, err := client.CreateChatCompletion(context.Background(), req)
|
resp, err := client.CreateChatCompletion(context.Background(), req)
|
||||||
@ -36,13 +68,46 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return resp.Choices[0].Message.Content, nil
|
choice := resp.Choices[0]
|
||||||
|
|
||||||
|
if len(choice.Message.ToolCalls) > 0 {
|
||||||
|
// Append the assistant's reply with its request for tool calls
|
||||||
|
toolCallJson, _ := json.Marshal(choice.Message.ToolCalls)
|
||||||
|
assistantReply := Message{
|
||||||
|
Role: MessageRoleAssistant,
|
||||||
|
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if replies != nil {
|
||||||
|
*replies = append(append(*replies, assistantReply), toolReplies...)
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = append(append(messages, assistantReply), toolReplies...)
|
||||||
|
// Recurse into CreateChatCompletion with the tool call replies added
|
||||||
|
// to the original messages
|
||||||
|
return CreateChatCompletion(model, append(messages, toolReplies...), maxTokens, replies)
|
||||||
|
}
|
||||||
|
|
||||||
|
if replies != nil {
|
||||||
|
*replies = append(*replies, Message{
|
||||||
|
Role: MessageRoleAssistant,
|
||||||
|
OriginalContent: choice.Message.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the user-facing message.
|
||||||
|
return choice.Message.Content, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateChatCompletionStream submits a streaming Chat Completion API request
|
// CreateChatCompletionStream submits a streaming Chat Completion API request
|
||||||
// and both returns and streams the response to the provided output channel.
|
// and both returns and streams the response to the provided output channel.
|
||||||
// May return a partial response if an error occurs mid-stream.
|
// May return a partial response if an error occurs mid-stream.
|
||||||
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string) (string, error) {
|
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string, replies *[]Message) (string, error) {
|
||||||
client := openai.NewClient(*config.OpenAI.APIKey)
|
client := openai.NewClient(*config.OpenAI.APIKey)
|
||||||
req := CreateChatCompletionRequest(model, messages, maxTokens)
|
req := CreateChatCompletionRequest(model, messages, maxTokens)
|
||||||
|
|
||||||
@ -52,7 +117,10 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
|
|||||||
}
|
}
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
|
|
||||||
sb := strings.Builder{}
|
content := strings.Builder{}
|
||||||
|
toolCalls := []openai.ToolCall{}
|
||||||
|
|
||||||
|
// Iterate stream segments
|
||||||
for {
|
for {
|
||||||
response, e := stream.Recv()
|
response, e := stream.Recv()
|
||||||
if errors.Is(e, io.EOF) {
|
if errors.Is(e, io.EOF) {
|
||||||
@ -63,9 +131,57 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
|
|||||||
err = e
|
err = e
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
chunk := response.Choices[0].Delta.Content
|
|
||||||
output <- chunk
|
delta := response.Choices[0].Delta
|
||||||
sb.WriteString(chunk)
|
if len(delta.ToolCalls) > 0 {
|
||||||
|
// Construct streamed tool_call arguments
|
||||||
|
for _, tc := range delta.ToolCalls {
|
||||||
|
if tc.Index == nil {
|
||||||
|
return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
|
||||||
|
}
|
||||||
|
if len(toolCalls) <= *tc.Index {
|
||||||
|
toolCalls = append(toolCalls, tc)
|
||||||
|
} else {
|
||||||
|
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
output <- delta.Content
|
||||||
|
content.WriteString(delta.Content)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return sb.String(), err
|
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
// Append the assistant's reply with its request for tool calls
|
||||||
|
toolCallJson, _ := json.Marshal(toolCalls)
|
||||||
|
|
||||||
|
assistantReply := Message{
|
||||||
|
Role: MessageRoleAssistant,
|
||||||
|
OriginalContent: content.String(),
|
||||||
|
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
toolReplies, err := ExecuteToolCalls(toolCalls)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if replies != nil {
|
||||||
|
*replies = append(append(*replies, assistantReply), toolReplies...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recurse into CreateChatCompletionStream with the tool call replies
|
||||||
|
// added to the original messages
|
||||||
|
messages = append(append(messages, assistantReply), toolReplies...)
|
||||||
|
return CreateChatCompletionStream(model, messages, maxTokens, output, replies)
|
||||||
|
}
|
||||||
|
|
||||||
|
if replies != nil {
|
||||||
|
*replies = append(*replies, Message{
|
||||||
|
Role: MessageRoleAssistant,
|
||||||
|
OriginalContent: content.String(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return content.String(), err
|
||||||
}
|
}
|
||||||
|
@ -23,8 +23,10 @@ type Message struct {
|
|||||||
ConversationID uint `gorm:"foreignKey:ConversationID"`
|
ConversationID uint `gorm:"foreignKey:ConversationID"`
|
||||||
Conversation Conversation
|
Conversation Conversation
|
||||||
OriginalContent string
|
OriginalContent string
|
||||||
Role string // 'user' or 'assistant'
|
Role MessageRole // one of: 'system', 'user', 'assistant', 'tool'
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
|
ToolCallID sql.NullString
|
||||||
|
ToolCalls sql.NullString // a json-encoded array of tool calls from the model
|
||||||
}
|
}
|
||||||
|
|
||||||
type Conversation struct {
|
type Conversation struct {
|
||||||
@ -95,6 +97,10 @@ func (s *Store) SaveMessage(message *Message) error {
|
|||||||
return s.db.Create(message).Error
|
return s.db.Create(message).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteMessage(message *Message) error {
|
||||||
|
return s.db.Delete(&message).Error
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) Conversations() ([]Conversation, error) {
|
func (s *Store) Conversations() ([]Conversation, error) {
|
||||||
var conversations []Conversation
|
var conversations []Conversation
|
||||||
err := s.db.Find(&conversations).Error
|
err := s.db.Find(&conversations).Error
|
||||||
|
Loading…
Reference in New Issue
Block a user