lmcli/pkg/cli/openai.go

163 lines
4.8 KiB
Go
Raw Normal View History

package cli
2023-10-30 15:23:07 -06:00
import (
"context"
"database/sql"
"encoding/json"
2023-10-30 15:45:21 -06:00
"errors"
"fmt"
2023-10-30 15:45:21 -06:00
"io"
"strings"
2023-11-04 16:56:22 -06:00
2023-10-30 15:23:07 -06:00
openai "github.com/sashabaranov/go-openai"
)
func CreateChatCompletionRequest(model string, messages []Message, maxTokens int) openai.ChatCompletionRequest {
chatCompletionMessages := []openai.ChatCompletionMessage{}
2023-11-04 16:56:22 -06:00
for _, m := range messages {
message := openai.ChatCompletionMessage{
2023-11-04 16:56:22 -06:00
Role: m.Role,
2023-10-30 15:23:07 -06:00
Content: m.OriginalContent,
}
if m.ToolCallID.Valid {
message.ToolCallID = m.ToolCallID.String
}
if m.ToolCalls.Valid {
// unmarshal directly into chatMessage.ToolCalls
err := json.Unmarshal([]byte(m.ToolCalls.String), &message.ToolCalls)
if err != nil {
// TODO: handle, this shouldn't really happen since
// we only save the successfully marshal'd data to database
fmt.Printf("Error unmarshalling the tool_calls JSON: %v\n", err)
}
}
chatCompletionMessages = append(chatCompletionMessages, message)
}
var tools []openai.Tool
for _, t := range AvailableTools {
// TODO: support some way to limit which tools are available per-request
tools = append(tools, t.Tool)
2023-10-30 15:23:07 -06:00
}
return openai.ChatCompletionRequest{
Model: model,
Messages: chatCompletionMessages,
MaxTokens: maxTokens,
N: 1, // limit responses to 1 "choice". we use choices[0] to reference it
Tools: tools,
ToolChoice: "auto", // TODO: allow limiting/forcing which function is called?
}
}
// CreateChatCompletion submits a Chat Completion API request and returns the
// response. CreateChatCompletion will recursively call itself in the case of
// tool calls, until a response is received with the final user-facing output.
func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) {
client := openai.NewClient(*config.OpenAI.APIKey)
req := CreateChatCompletionRequest(model, messages, maxTokens)
resp, err := client.CreateChatCompletion(context.Background(), req)
2023-10-30 15:23:07 -06:00
if err != nil {
return "", err
2023-10-30 15:23:07 -06:00
}
choice := resp.Choices[0]
if len(choice.Message.ToolCalls) > 0 {
if choice.Message.Content != "" {
return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.")
}
// Append the assistant's reply with its request for tool calls
toolCallJson, _ := json.Marshal(choice.Message.ToolCalls)
messages = append(messages, Message{
Role: "assistant",
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
})
toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls)
if err != nil {
return "", err
}
// Recurse into CreateChatCompletion with the tool call replies added
// to the original messages
return CreateChatCompletion(model, append(messages, toolReplies...), maxTokens)
}
// Return the user-facing message.
return choice.Message.Content, nil
2023-10-30 15:23:07 -06:00
}
2023-10-30 15:45:21 -06:00
// CreateChatCompletionStream submits a streaming Chat Completion API request
// and both returns and streams the response to the provided output channel.
// May return a partial response if an error occurs mid-stream.
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string) (string, error) {
client := openai.NewClient(*config.OpenAI.APIKey)
req := CreateChatCompletionRequest(model, messages, maxTokens)
2023-10-30 15:45:21 -06:00
stream, err := client.CreateChatCompletionStream(context.Background(), req)
2023-10-30 15:45:21 -06:00
if err != nil {
return "", err
2023-10-30 15:45:21 -06:00
}
defer stream.Close()
content := strings.Builder{}
toolCalls := []openai.ToolCall{}
// Iterate stream segments
2023-10-30 15:45:21 -06:00
for {
response, e := stream.Recv()
if errors.Is(e, io.EOF) {
break
2023-10-30 15:45:21 -06:00
}
if e != nil {
err = e
break
2023-10-30 15:45:21 -06:00
}
delta := response.Choices[0].Delta
if len(delta.ToolCalls) > 0 {
// Construct streamed tool_call arguments
for _, tc := range delta.ToolCalls {
if tc.Index == nil {
return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
}
if len(toolCalls) <= *tc.Index {
toolCalls = append(toolCalls, tc)
} else {
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
}
}
} else {
output <- delta.Content
content.WriteString(delta.Content)
}
2023-10-30 15:45:21 -06:00
}
if len(toolCalls) > 0 {
if content.String() != "" {
return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.")
}
// Append the assistant's reply with its request for tool calls
toolCallJson, _ := json.Marshal(toolCalls)
messages = append(messages, Message{
Role: "assistant",
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
})
toolReplies, err := ExecuteToolCalls(toolCalls)
if err != nil {
return "", err
}
// Recurse into CreateChatCompletionStream with the tool call replies
// added to the original messages
return CreateChatCompletionStream(model, append(messages, toolReplies...), maxTokens, output)
}
return content.String(), err
2023-10-30 15:45:21 -06:00
}