From 2c64ab501bce82eaa15f9a082ccb88f1146be910 Mon Sep 17 00:00:00 2001 From: Matt Low Date: Sun, 5 Nov 2023 07:54:12 +0000 Subject: [PATCH] Treat the system message like any other Removed the system parameter on ChatCompletion functions, and persist it in conversations as well. --- pkg/cli/cmd.go | 34 ++++++++++++++++++++++++---------- pkg/cli/openai.go | 18 ++++++------------ 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/pkg/cli/cmd.go b/pkg/cli/cmd.go index 63c69a9..63faccc 100644 --- a/pkg/cli/cmd.go +++ b/pkg/cli/cmd.go @@ -84,17 +84,26 @@ var newCmd = &cobra.Command{ Fatal("Could not save new conversation: %v\n", err) } - message := Message{ - ConversationID: conversation.ID, - Role: "user", - OriginalContent: messageContents, + const system = "You are a helpful assistant." + messages := []Message{ + { + ConversationID: conversation.ID, + Role: "system", + OriginalContent: system, + }, + { + ConversationID: conversation.ID, + Role: "user", + OriginalContent: messageContents, + }, } - err = store.SaveMessage(&message) - if err != nil { - Warn("Could not save message: %v\n", err) + for _, message := range(messages) { + err = store.SaveMessage(&message) + if err != nil { + Warn("Could not save %s message: %v\n", message.Role, err) + } } - const system = "You are a helpful assistant." fmt.Printf("\n\n%s\n\n", system) fmt.Printf("\n\n%s\n\n", messageContents) fmt.Print("\n\n") @@ -105,7 +114,7 @@ var newCmd = &cobra.Command{ response <- HandleDelayedResponse(receiver) }() - err = CreateChatCompletionStream(system, []Message{message}, receiver) + err = CreateChatCompletionStream(messages, receiver) if err != nil { Fatal("%v\n", err) } @@ -134,7 +143,12 @@ var promptCmd = &cobra.Command{ Fatal("No message was provided.\n") } + const system = "You are a helpful assistant." 
messages := []Message{ + { + Role: "system", + OriginalContent: system, + }, { Role: "user", OriginalContent: message, @@ -143,7 +157,7 @@ var promptCmd = &cobra.Command{ receiver := make(chan string) go HandleDelayedResponse(receiver) - err := CreateChatCompletionStream("You are a helpful assistant.", messages, receiver) + err := CreateChatCompletionStream(messages, receiver) if err != nil { Fatal("%v\n", err) } diff --git a/pkg/cli/openai.go b/pkg/cli/openai.go index 060bf63..1251426 100644 --- a/pkg/cli/openai.go +++ b/pkg/cli/openai.go @@ -8,14 +8,8 @@ import ( openai "github.com/sashabaranov/go-openai" ) -func CreateChatCompletionRequest(system string, messages []Message) *openai.ChatCompletionRequest { - chatCompletionMessages := []openai.ChatCompletionMessage{ - { - Role: "system", - Content: system, - }, - } - +func CreateChatCompletionRequest(messages []Message) *openai.ChatCompletionRequest { + chatCompletionMessages := []openai.ChatCompletionMessage{} for _, m := range messages { chatCompletionMessages = append(chatCompletionMessages, openai.ChatCompletionMessage{ Role: m.Role, @@ -33,11 +27,11 @@ func CreateChatCompletionRequest(system string, messages []Message) *openai.Chat // CreateChatCompletion accepts a slice of Message and returns the response // of the Large Language Model. -func CreateChatCompletion(system string, messages []Message) (string, error) { +func CreateChatCompletion(messages []Message) (string, error) { client := openai.NewClient(config.OpenAI.APIKey) resp, err := client.CreateChatCompletion( context.Background(), - *CreateChatCompletionRequest(system, messages), + *CreateChatCompletionRequest(messages), ) if err != nil { @@ -49,11 +43,11 @@ func CreateChatCompletion(system string, messages []Message) (string, error) { // CreateChatCompletionStream submits an streaming Chat Completion API request // and sends the received data to the output channel. 
-func CreateChatCompletionStream(system string, messages []Message, output chan string) error { +func CreateChatCompletionStream(messages []Message, output chan string) error { client := openai.NewClient(config.OpenAI.APIKey) ctx := context.Background() - req := CreateChatCompletionRequest(system, messages) + req := CreateChatCompletionRequest(messages) req.Stream = true defer close(output)