Matt Low
b229c42811
Tool calls are so far only supported on the non-streaming endpoint. Added the `read_dir` tool for reading the contents of paths relative to the current working directory.
123 lines
3.7 KiB
Go
package cli

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strings"

	openai "github.com/sashabaranov/go-openai"
)

func CreateChatCompletionRequest(model string, messages []Message, maxTokens int) openai.ChatCompletionRequest {
	chatCompletionMessages := []openai.ChatCompletionMessage{}
	for _, m := range messages {
		message := openai.ChatCompletionMessage{
			Role:    m.Role,
			Content: m.OriginalContent,
		}
		if m.ToolCallID.Valid {
			message.ToolCallID = m.ToolCallID.String
		}
		if m.ToolCalls.Valid {
			// unmarshal directly into message.ToolCalls
			err := json.Unmarshal([]byte(m.ToolCalls.String), &message.ToolCalls)
			if err != nil {
				// TODO: handle; this shouldn't really happen, since we only
				// save successfully marshaled data to the database
				fmt.Printf("Error unmarshalling the tool_calls JSON: %v\n", err)
			}
		}
		chatCompletionMessages = append(chatCompletionMessages, message)
	}

	var tools []openai.Tool
	for _, t := range AvailableTools {
		// TODO: support some way to limit which tools are available per-request
		tools = append(tools, t.Tool)
	}

	return openai.ChatCompletionRequest{
		Model:      model,
		Messages:   chatCompletionMessages,
		MaxTokens:  maxTokens,
		N:          1, // limit responses to 1 "choice". we use choices[0] to reference it
		Tools:      tools,
		ToolChoice: "auto", // TODO: allow limiting/forcing which function is called?
	}
}

// CreateChatCompletion submits a Chat Completion API request and returns the
// response. CreateChatCompletion will recursively call itself in the case of
// tool calls, until a response is received with the final user-facing output.
func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) {
	client := openai.NewClient(*config.OpenAI.APIKey)
	req := CreateChatCompletionRequest(model, messages, maxTokens)
	resp, err := client.CreateChatCompletion(context.Background(), req)
	if err != nil {
		return "", err
	}

	choice := resp.Choices[0]

	if len(choice.Message.ToolCalls) > 0 {
		if choice.Message.Content != "" {
			return "", fmt.Errorf("model replied with user-facing content in addition to tool calls (unsupported)")
		}

		// Append the assistant's reply with its request for tool calls.
		// The Marshal error is safe to ignore here: ToolCalls was just
		// decoded from a well-formed API response.
		toolCallJson, _ := json.Marshal(choice.Message.ToolCalls)
		messages = append(messages, Message{
			ConversationID: messages[0].ConversationID,
			Role:           "assistant",
			ToolCalls:      sql.NullString{String: string(toolCallJson), Valid: true},
		})

		toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls)
		if err != nil {
			return "", err
		}

		// Recurse into CreateChatCompletion with the tool call replies added
		// to the original messages
		return CreateChatCompletion(model, append(messages, toolReplies...), maxTokens)
	}

	// Return the user-facing message.
	return choice.Message.Content, nil
}

// CreateChatCompletionStream submits a streaming Chat Completion API request
// and both returns and streams the response to the provided output channel.
// May return a partial response if an error occurs mid-stream.
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string) (string, error) {
	client := openai.NewClient(*config.OpenAI.APIKey)
	req := CreateChatCompletionRequest(model, messages, maxTokens)

	stream, err := client.CreateChatCompletionStream(context.Background(), req)
	if err != nil {
		return "", err
	}
	defer stream.Close()

	sb := strings.Builder{}
	for {
		response, e := stream.Recv()
		if errors.Is(e, io.EOF) {
			break
		}
		if e != nil {
			err = e
			break
		}
		chunk := response.Choices[0].Delta.Content
		output <- chunk
		sb.WriteString(chunk)
	}
	return sb.String(), err
}
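
The `AvailableTools` registry and `ExecuteToolCalls` helper referenced above live elsewhere in the package. As a rough sketch of how the `read_dir` tool from the commit message might plug into them (the `AvailableTool` wrapper type, its `Impl` callback signature, and the JSON schema shape are assumptions for this sketch; only `openai.Tool` and friends come from go-openai, which in recent versions takes a *FunctionDefinition):

package cli

import (
	"database/sql"
	"encoding/json"
	"fmt"
	"os"
	"strings"

	openai "github.com/sashabaranov/go-openai"
)

// AvailableTool pairs a tool definition sent to the API with the local
// function that implements it. NOTE: this wrapper and the Impl signature
// are assumptions; only the .Tool field is implied by the code above.
type AvailableTool struct {
	Tool openai.Tool
	Impl func(arguments string) (string, error)
}

var AvailableTools = map[string]AvailableTool{
	"read_dir": {
		Tool: openai.Tool{
			Type: openai.ToolTypeFunction,
			Function: &openai.FunctionDefinition{
				Name:        "read_dir",
				Description: "List the contents of a path relative to the current working directory.",
				Parameters: map[string]any{
					"type": "object",
					"properties": map[string]any{
						"relative_dir": map[string]any{
							"type":        "string",
							"description": "The directory to list.",
						},
					},
				},
			},
		},
		Impl: func(arguments string) (string, error) {
			var args struct {
				RelativeDir string `json:"relative_dir"`
			}
			if err := json.Unmarshal([]byte(arguments), &args); err != nil {
				return "", err
			}
			entries, err := os.ReadDir(args.RelativeDir)
			if err != nil {
				return "", err
			}
			names := make([]string, 0, len(entries))
			for _, entry := range entries {
				names = append(names, entry.Name())
			}
			return strings.Join(names, "\n"), nil
		},
	},
}

// ExecuteToolCalls runs each requested tool and returns one "tool" role
// Message per call, linked back to the request via ToolCallID.
func ExecuteToolCalls(toolCalls []openai.ToolCall) ([]Message, error) {
	var replies []Message
	for _, call := range toolCalls {
		tool, ok := AvailableTools[call.Function.Name]
		if !ok {
			return nil, fmt.Errorf("unknown tool: %s", call.Function.Name)
		}
		result, err := tool.Impl(call.Function.Arguments)
		if err != nil {
			return nil, err
		}
		replies = append(replies, Message{
			Role:            "tool",
			OriginalContent: result,
			ToolCallID:      sql.NullString{String: call.ID, Valid: true},
		})
	}
	return replies, nil
}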
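
A hypothetical call site for the non-streaming path; the model name, prompt, and token limit are placeholders, while `Message` and `CreateChatCompletion` come from this package:

package cli

import "fmt"

// If the model decides to call read_dir, CreateChatCompletion executes the
// tool and recurses until it gets a user-facing reply.
func ExampleNonStreaming() {
	msgs := []Message{
		{Role: "user", OriginalContent: "What files are in the current directory?"},
	}
	reply, err := CreateChatCompletion("gpt-4", msgs, 256)
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(reply)
}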
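
And a sketch of the streaming path, where the caller owns the output channel and is responsible for draining it; per the commit message, tool calls are not yet handled on this endpoint:

package cli

import "fmt"

// Prints chunks as they arrive, then closes the channel once the request
// returns so the printer goroutine can finish.
func ExampleStreaming() {
	msgs := []Message{
		{Role: "user", OriginalContent: "Tell me a short story."},
	}

	output := make(chan string)
	done := make(chan struct{})
	go func() {
		for chunk := range output {
			fmt.Print(chunk)
		}
		close(done)
	}()

	full, err := CreateChatCompletionStream("gpt-4", msgs, 256, output)
	close(output) // no more chunks will be sent
	<-done
	if err != nil {
		fmt.Println("\nstream error:", err)
	}
	_ = full // the accumulated response; may be partial if err != nil
}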