lmcli/pkg/cli/openai.go


package cli

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strings"

	openai "github.com/sashabaranov/go-openai"
)
// CreateChatCompletionRequest builds an openai.ChatCompletionRequest from the
// given model, message history, and token limit, attaching any enabled tools.
func CreateChatCompletionRequest(model string, messages []Message, maxTokens int) openai.ChatCompletionRequest {
	chatCompletionMessages := []openai.ChatCompletionMessage{}
	for _, m := range messages {
		message := openai.ChatCompletionMessage{
			Role:    string(m.Role),
			Content: m.OriginalContent,
		}
		if m.ToolCallID.Valid {
			message.ToolCallID = m.ToolCallID.String
		}
		if m.ToolCalls.Valid {
			// unmarshal directly into message.ToolCalls
			err := json.Unmarshal([]byte(m.ToolCalls.String), &message.ToolCalls)
			if err != nil {
				// TODO: handle this properly. It shouldn't really happen,
				// since we only save successfully marshaled data to the
				// database.
				fmt.Printf("Error unmarshalling the tool_calls JSON: %v\n", err)
			}
		}
		chatCompletionMessages = append(chatCompletionMessages, message)
	}

	request := openai.ChatCompletionRequest{
		Model:     model,
		Messages:  chatCompletionMessages,
		MaxTokens: maxTokens,
		N:         1, // limit responses to 1 "choice"; we use choices[0] to reference it
	}

	var tools []openai.Tool
	for _, t := range config.OpenAI.EnabledTools {
		tool, ok := AvailableTools[t]
		if ok {
			tools = append(tools, tool.Tool)
		}
	}

	if len(tools) > 0 {
		request.Tools = tools
		request.ToolChoice = "auto"
	}

	return request
}
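
// As a hedged illustration (not part of the original file): a request is
// built from the package's Message history. MessageRoleUser and the "gpt-4"
// model name are assumptions, not confirmed by this file.
//
//	msgs := []Message{{Role: MessageRoleUser, OriginalContent: "Hello!"}}
//	req := CreateChatCompletionRequest("gpt-4", msgs, 256)
//	// req.Tools and req.ToolChoice are set only when tools are enabled in config.
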
// CreateChatCompletion submits a Chat Completion API request and returns the
// response. CreateChatCompletion will recursively call itself in the case of
// tool calls, until a response is received with the final user-facing output.
func CreateChatCompletion(model string, messages []Message, maxTokens int, replies *[]Message) (string, error) {
	client := openai.NewClient(*config.OpenAI.APIKey)
	req := CreateChatCompletionRequest(model, messages, maxTokens)
	resp, err := client.CreateChatCompletion(context.Background(), req)
	if err != nil {
		return "", err
	}

	choice := resp.Choices[0]
	if len(choice.Message.ToolCalls) > 0 {
		// Append the assistant's reply with its request for tool calls
		toolCallJson, _ := json.Marshal(choice.Message.ToolCalls)
		assistantReply := Message{
			Role:      MessageRoleAssistant,
			ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
		}

		toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls)
		if err != nil {
			return "", err
		}

		if replies != nil {
			*replies = append(append(*replies, assistantReply), toolReplies...)
		}
		messages = append(append(messages, assistantReply), toolReplies...)

		// Recurse into CreateChatCompletion with the tool call replies added
		// to the original messages
		return CreateChatCompletion(model, messages, maxTokens, replies)
	}

	if replies != nil {
		*replies = append(*replies, Message{
			Role:            MessageRoleAssistant,
			OriginalContent: choice.Message.Content,
		})
	}

	// Return the user-facing message.
	return choice.Message.Content, nil
}
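
// A hedged usage sketch (illustrative only; MessageRoleUser and the "gpt-4"
// model name are assumptions): replies collects every intermediate assistant
// and tool message, while the return value is the final user-facing text.
//
//	var replies []Message
//	content, err := CreateChatCompletion("gpt-4", []Message{
//		{Role: MessageRoleUser, OriginalContent: "What's 2+2?"},
//	}, 256, &replies)
//	if err != nil {
//		// handle error
//	}
//	fmt.Println(content) // final output, after any tool call round-trips
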
// CreateChatCompletionStream submits a streaming Chat Completion API request
// and both returns and streams the response to the provided output channel.
// May return a partial response if an error occurs mid-stream.
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string, replies *[]Message) (string, error) {
	client := openai.NewClient(*config.OpenAI.APIKey)
	req := CreateChatCompletionRequest(model, messages, maxTokens)

	stream, err := client.CreateChatCompletionStream(context.Background(), req)
	if err != nil {
		return "", err
	}
	defer stream.Close()

	content := strings.Builder{}
	toolCalls := []openai.ToolCall{}

	// Iterate stream segments
	for {
		response, e := stream.Recv()
		if errors.Is(e, io.EOF) {
			break
		}
		if e != nil {
			err = e
			break
		}

		delta := response.Choices[0].Delta
		if len(delta.ToolCalls) > 0 {
			// Construct streamed tool_call arguments
			for _, tc := range delta.ToolCalls {
				if tc.Index == nil {
					return "", fmt.Errorf("unexpected nil index for streamed tool call")
				}
				if len(toolCalls) <= *tc.Index {
					toolCalls = append(toolCalls, tc)
				} else {
					toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
				}
			}
		} else {
			output <- delta.Content
			content.WriteString(delta.Content)
		}
	}

	if len(toolCalls) > 0 {
		// Append the assistant's reply with its request for tool calls
		toolCallJson, _ := json.Marshal(toolCalls)
		assistantReply := Message{
			Role:            MessageRoleAssistant,
			OriginalContent: content.String(),
			ToolCalls:       sql.NullString{String: string(toolCallJson), Valid: true},
		}

		toolReplies, err := ExecuteToolCalls(toolCalls)
		if err != nil {
			return "", err
		}

		if replies != nil {
			*replies = append(append(*replies, assistantReply), toolReplies...)
		}

		// Recurse into CreateChatCompletionStream with the tool call replies
		// added to the original messages
		messages = append(append(messages, assistantReply), toolReplies...)
		return CreateChatCompletionStream(model, messages, maxTokens, output, replies)
	}

	if replies != nil {
		*replies = append(*replies, Message{
			Role:            MessageRoleAssistant,
			OriginalContent: content.String(),
		})
	}

	return content.String(), err
}
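
// A hedged usage sketch for the streaming variant (illustrative only; the
// "gpt-4" model name is an assumption): the output channel should be drained
// concurrently, since the function blocks while sending each streamed segment.
// The function does not close the channel; the caller does so after it returns.
//
//	output := make(chan string)
//	go func() {
//		for chunk := range output {
//			fmt.Print(chunk)
//		}
//	}()
//	content, err := CreateChatCompletionStream("gpt-4", msgs, 256, output, nil)
//	close(output)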