package cli import ( "context" "database/sql" "encoding/json" "errors" "fmt" "io" "strings" openai "github.com/sashabaranov/go-openai" ) func CreateChatCompletionRequest(model string, messages []Message, maxTokens int) openai.ChatCompletionRequest { chatCompletionMessages := []openai.ChatCompletionMessage{} for _, m := range messages { message := openai.ChatCompletionMessage{ Role: string(m.Role), Content: m.OriginalContent, } if m.ToolCallID.Valid { message.ToolCallID = m.ToolCallID.String } if m.ToolCalls.Valid { // unmarshal directly into chatMessage.ToolCalls err := json.Unmarshal([]byte(m.ToolCalls.String), &message.ToolCalls) if err != nil { // TODO: handle, this shouldn't really happen since // we only save the successfully marshal'd data to database fmt.Printf("Error unmarshalling the tool_calls JSON: %v\n", err) } } chatCompletionMessages = append(chatCompletionMessages, message) } request := openai.ChatCompletionRequest{ Model: model, Messages: chatCompletionMessages, MaxTokens: maxTokens, N: 1, // limit responses to 1 "choice". we use choices[0] to reference it } var tools []openai.Tool for _, t := range config.OpenAI.EnabledTools { tool, ok := AvailableTools[t] if ok { tools = append(tools, tool.Tool) } } if len(tools) > 0 { request.Tools = tools request.ToolChoice = "auto" } return request } // CreateChatCompletion submits a Chat Completion API request and returns the // response. CreateChatCompletion will recursively call itself in the case of // tool calls, until a response is received with the final user-facing output. func CreateChatCompletion(model string, messages []Message, maxTokens int, replies *[]Message) (string, error) { client := openai.NewClient(*config.OpenAI.APIKey) req := CreateChatCompletionRequest(model, messages, maxTokens) resp, err := client.CreateChatCompletion(context.Background(), req) if err != nil { return "", err } choice := resp.Choices[0] if len(choice.Message.ToolCalls) > 0 { // Append the assistant's reply with its request for tool calls toolCallJson, _ := json.Marshal(choice.Message.ToolCalls) assistantReply := Message{ Role: MessageRoleAssistant, ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true}, } toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls) if err != nil { return "", err } if replies != nil { *replies = append(append(*replies, assistantReply), toolReplies...) } messages = append(append(messages, assistantReply), toolReplies...) // Recurse into CreateChatCompletion with the tool call replies added // to the original messages return CreateChatCompletion(model, append(messages, toolReplies...), maxTokens, replies) } if replies != nil { *replies = append(*replies, Message{ Role: MessageRoleAssistant, OriginalContent: choice.Message.Content, }) } // Return the user-facing message. return choice.Message.Content, nil } // CreateChatCompletionStream submits a streaming Chat Completion API request // and both returns and streams the response to the provided output channel. // May return a partial response if an error occurs mid-stream. func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string, replies *[]Message) (string, error) { client := openai.NewClient(*config.OpenAI.APIKey) req := CreateChatCompletionRequest(model, messages, maxTokens) stream, err := client.CreateChatCompletionStream(context.Background(), req) if err != nil { return "", err } defer stream.Close() content := strings.Builder{} toolCalls := []openai.ToolCall{} // Iterate stream segments for { response, e := stream.Recv() if errors.Is(e, io.EOF) { break } if e != nil { err = e break } delta := response.Choices[0].Delta if len(delta.ToolCalls) > 0 { // Construct streamed tool_call arguments for _, tc := range delta.ToolCalls { if tc.Index == nil { return "", fmt.Errorf("Unexpected nil index for streamed tool call.") } if len(toolCalls) <= *tc.Index { toolCalls = append(toolCalls, tc) } else { toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments } } } else { output <- delta.Content content.WriteString(delta.Content) } } if len(toolCalls) > 0 { // Append the assistant's reply with its request for tool calls toolCallJson, _ := json.Marshal(toolCalls) assistantReply := Message{ Role: MessageRoleAssistant, OriginalContent: content.String(), ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true}, } toolReplies, err := ExecuteToolCalls(toolCalls) if err != nil { return "", err } if replies != nil { *replies = append(append(*replies, assistantReply), toolReplies...) } // Recurse into CreateChatCompletionStream with the tool call replies // added to the original messages messages = append(append(messages, assistantReply), toolReplies...) return CreateChatCompletionStream(model, messages, maxTokens, output, replies) } if replies != nil { *replies = append(*replies, Message{ Role: MessageRoleAssistant, OriginalContent: content.String(), }) } return content.String(), err }