From 5c6ec5e4e289a40777e7a60af8c592261a70451b Mon Sep 17 00:00:00 2001 From: Matt Low Date: Sat, 4 Nov 2023 22:07:06 +0000 Subject: [PATCH] Include system prompt in OpenAI chat completion requests --- pkg/cli/openai.go | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/pkg/cli/openai.go b/pkg/cli/openai.go index 4d90bf6..7853b23 100644 --- a/pkg/cli/openai.go +++ b/pkg/cli/openai.go @@ -8,8 +8,14 @@ import ( openai "github.com/sashabaranov/go-openai" ) -func CreateChatCompletionRequest(messages []Message) (openai.ChatCompletionRequest) { - var chatCompletionMessages []openai.ChatCompletionMessage +func CreateChatCompletionRequest(system string, messages []Message) (*openai.ChatCompletionRequest) { + chatCompletionMessages := []openai.ChatCompletionMessage{ + { + Role: "system", + Content: system, + }, + } + for _, m := range(messages) { chatCompletionMessages = append(chatCompletionMessages, openai.ChatCompletionMessage{ Role: m.Role, @@ -17,7 +23,7 @@ func CreateChatCompletionRequest(messages []Message) (openai.ChatCompletionReque }) } - return openai.ChatCompletionRequest{ + return &openai.ChatCompletionRequest{ Model: openai.GPT4, MaxTokens: 256, Messages: chatCompletionMessages, @@ -31,24 +37,24 @@ func CreateChatCompletion(system string, messages []Message) (string, error) { client := openai.NewClient(config.OpenAI.APIKey) resp, err := client.CreateChatCompletion( context.Background(), - CreateChatCompletionRequest(messages), + *CreateChatCompletionRequest(system, messages), ) if err != nil { - return "", fmt.Errorf("ChatCompletion error: %v\n", err) + return "", err } return resp.Choices[0].Message.Content, nil } -func CreateChatCompletionStream(system string, messages []Message, output io.Writer) (error) { +func CreateChatCompletionStream(system string, messages []Message, output io.Writer) error { client := openai.NewClient(config.OpenAI.APIKey) ctx := context.Background() - req := CreateChatCompletionRequest(messages) + req := CreateChatCompletionRequest(system, messages) req.Stream = true - stream, err := client.CreateChatCompletionStream(ctx, req) + stream, err := client.CreateChatCompletionStream(ctx, *req) if err != nil { return err }