package cli

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strings"

	openai "github.com/sashabaranov/go-openai"
)

// CreateChatCompletionRequest builds a Chat Completion request from the given
// model, message history, and token limit, attaching any locally available
// tools.
func CreateChatCompletionRequest(model string, messages []Message, maxTokens int) openai.ChatCompletionRequest {
	chatCompletionMessages := []openai.ChatCompletionMessage{}
	for _, m := range messages {
		message := openai.ChatCompletionMessage{
			Role:    m.Role,
			Content: m.OriginalContent,
		}
		if m.ToolCallID.Valid {
			message.ToolCallID = m.ToolCallID.String
		}
		if m.ToolCalls.Valid {
			// Unmarshal directly into message.ToolCalls
			err := json.Unmarshal([]byte(m.ToolCalls.String), &message.ToolCalls)
			if err != nil {
				// TODO: handle; this shouldn't really happen since we only
				// save successfully marshaled data to the database
				fmt.Printf("Error unmarshalling the tool_calls JSON: %v\n", err)
			}
		}
		chatCompletionMessages = append(chatCompletionMessages, message)
	}

	var tools []openai.Tool
	for _, t := range AvailableTools {
		// TODO: support some way to limit which tools are available per-request
		tools = append(tools, t.Tool)
	}

	return openai.ChatCompletionRequest{
		Model:      model,
		Messages:   chatCompletionMessages,
		MaxTokens:  maxTokens,
		N:          1, // limit responses to one "choice"; we reference it as choices[0]
		Tools:      tools,
		ToolChoice: "auto", // TODO: allow limiting/forcing which function is called?
	}
}

// CreateChatCompletion submits a Chat Completion API request and returns the
// response. CreateChatCompletion will recursively call itself in the case of
// tool calls, until a response is received with the final user-facing output.
func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) {
	client := openai.NewClient(*config.OpenAI.APIKey)
	req := CreateChatCompletionRequest(model, messages, maxTokens)

	resp, err := client.CreateChatCompletion(context.Background(), req)
	if err != nil {
		return "", err
	}

	choice := resp.Choices[0]

	if len(choice.Message.ToolCalls) > 0 {
		if choice.Message.Content != "" {
			return "", errors.New("model replied with user-facing content in addition to tool calls; unsupported")
		}

		// Append the assistant's reply with its request for tool calls
		toolCallJSON, err := json.Marshal(choice.Message.ToolCalls)
		if err != nil {
			return "", fmt.Errorf("could not marshal tool calls: %w", err)
		}
		messages = append(messages, Message{
			Role:      "assistant",
			ToolCalls: sql.NullString{String: string(toolCallJSON), Valid: true},
		})

		toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls)
		if err != nil {
			return "", err
		}

		// Recurse into CreateChatCompletion with the tool call replies added
		// to the original messages
		return CreateChatCompletion(model, append(messages, toolReplies...), maxTokens)
	}

	// Return the user-facing message.
	return choice.Message.Content, nil
}
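// exampleChatOnce is an illustrative sketch of how CreateChatCompletion might
// be called: it wraps a single user prompt in a Message and requests one
// completion. The helper name, the "gpt-4" model string, and the 512-token
// limit are arbitrary examples rather than part of the package's API, and it
// assumes the package-level config has already been loaded elsewhere.
func exampleChatOnce(prompt string) (string, error) {
	messages := []Message{{
		Role:            "user",
		OriginalContent: prompt,
	}}
	return CreateChatCompletion("gpt-4", messages, 512)
}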
// CreateChatCompletionStream submits a streaming Chat Completion API request
// and both returns and streams the response to the provided output channel.
// May return a partial response if an error occurs mid-stream.
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string) (string, error) {
	client := openai.NewClient(*config.OpenAI.APIKey)
	req := CreateChatCompletionRequest(model, messages, maxTokens)

	stream, err := client.CreateChatCompletionStream(context.Background(), req)
	if err != nil {
		return "", err
	}
	defer stream.Close()

	content := strings.Builder{}
	toolCalls := []openai.ToolCall{}

	// Iterate stream segments
	for {
		response, e := stream.Recv()
		if errors.Is(e, io.EOF) {
			break
		}
		if e != nil {
			err = e
			break
		}

		delta := response.Choices[0].Delta
		if len(delta.ToolCalls) > 0 {
			// Reassemble the streamed tool_call arguments, which arrive in
			// fragments keyed by index
			for _, tc := range delta.ToolCalls {
				if tc.Index == nil {
					return "", errors.New("unexpected nil index for streamed tool call")
				}
				if len(toolCalls) <= *tc.Index {
					toolCalls = append(toolCalls, tc)
				} else {
					toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
				}
			}
		} else {
			output <- delta.Content
			content.WriteString(delta.Content)
		}
	}

	if len(toolCalls) > 0 {
		if content.String() != "" {
			return "", errors.New("model replied with user-facing content in addition to tool calls; unsupported")
		}

		// Append the assistant's reply with its request for tool calls
		toolCallJSON, err := json.Marshal(toolCalls)
		if err != nil {
			return "", fmt.Errorf("could not marshal tool calls: %w", err)
		}
		messages = append(messages, Message{
			Role:      "assistant",
			ToolCalls: sql.NullString{String: string(toolCallJSON), Valid: true},
		})

		toolReplies, err := ExecuteToolCalls(toolCalls)
		if err != nil {
			return "", err
		}

		// Recurse into CreateChatCompletionStream with the tool call replies
		// added to the original messages
		return CreateChatCompletionStream(model, append(messages, toolReplies...), maxTokens, output)
	}

	return content.String(), err
}
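// exampleStreamToStdout is an illustrative sketch of how
// CreateChatCompletionStream might be driven: a goroutine drains the output
// channel to stdout while the call blocks until the final reply (including
// any tool-call round trips) is complete. The helper name is hypothetical and
// not used elsewhere; it assumes the package-level config has been loaded.
func exampleStreamToStdout(model string, messages []Message, maxTokens int) (string, error) {
	output := make(chan string)
	done := make(chan struct{})

	go func() {
		// Print each streamed chunk as it arrives
		for chunk := range output {
			fmt.Print(chunk)
		}
		close(done)
	}()

	reply, err := CreateChatCompletionStream(model, messages, maxTokens, output)

	// No more chunks will be sent; let the printer goroutine drain and exit
	close(output)
	<-done

	return reply, err
}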