From ed6ee9bea95ae42603353a3713169285d013fa03 Mon Sep 17 00:00:00 2001
From: Matt Low
Date: Wed, 29 Nov 2023 04:43:53 +0000
Subject: [PATCH] Add *[]Message parameter to CreateChatCompletion methods

Allows replies (tool calls, user-facing messages) to be added in
sequence as CreateChatCompletion* recurses into itself.

Cleaned up cmd.go: no longer need to create a Message based on the
string content response.
---
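A note on the new parameter: it is a pointer to a slice, *[]Message,
because append may reallocate the backing array; assigning through the
pointer is what keeps messages accumulated by the recursive calls
visible to the caller once the recursion unwinds. Callers that don't
need the replies (e.g. GenerateTitle below) pass nil. A minimal sketch
of the pattern -- appendReply is a hypothetical name, not a function in
this patch:

	// appendReply appends m to the slice that replies points at, if any.
	// A plain []Message parameter would not work here: append can return
	// a new backing array, leaving the caller's slice header unchanged.
	func appendReply(replies *[]Message, m Message) {
		if replies != nil {
			*replies = append(*replies, m)
		}
	}
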
 pkg/cli/cmd.go          | 98 ++++++++++++++++++++++-------------
 pkg/cli/conversation.go | 20 ++++++---
 pkg/cli/openai.go       | 58 +++++++++++++++---------
 pkg/cli/store.go        |  2 +-
 4 files changed, 105 insertions(+), 73 deletions(-)

diff --git a/pkg/cli/cmd.go b/pkg/cli/cmd.go
index d30cef2..93e8d62 100644
--- a/pkg/cli/cmd.go
+++ b/pkg/cli/cmd.go
@@ -53,9 +53,10 @@ func SystemPrompt() string {
 	return systemPrompt
 }
 
-// LLMRequest prompts the LLM with the given Message, writes the (partial)
-// response to stdout, and returns the (partial) response or any errors.
-func LLMRequest(messages []Message) (string, error) {
+// LLMRequest prompts the LLM with the given messages, writing the response
+// to stdout. Returns all reply messages added by the LLM, including any
+// function call messages.
+func LLMRequest(messages []Message) ([]Message, error) {
 	// receiver receives the reponse from LLM
 	receiver := make(chan string)
 	defer close(receiver)
@@ -63,7 +64,8 @@ func LLMRequest(messages []Message) (string, error) {
 	// start HandleDelayedContent goroutine to print received data to stdout
 	go HandleDelayedContent(receiver)
 
-	response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
+	var replies []Message
+	response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver, &replies)
 
 	if response != "" {
 		if err != nil {
 			Warn("Received partial response. Error: %v\n", err)
 		}
@@ -73,7 +75,7 @@
 		fmt.Println()
 	}
 
-	return response, err
+	return replies, err
 }
 
 // InputFromArgsOrEditor returns either the provided input from the args slice
@@ -316,22 +318,23 @@ var replyCmd = &cobra.Command{
 		messages = append(messages, userReply)
 
 		RenderConversation(messages, true)
 
-		assistantReply := Message{
-			ConversationID: conversation.ID,
-			Role:           "assistant",
-		}
-		assistantReply.RenderTTY()
+		// output a message heading
+		(&Message{
+			Role: MessageRoleAssistant,
+		}).RenderTTY()
 
-		response, err := LLMRequest(messages)
+		replies, err := LLMRequest(messages)
 		if err != nil {
 			Fatal("Error fetching LLM response: %v\n", err)
 		}
 
-		assistantReply.OriginalContent = response
+		for _, reply := range replies {
+			reply.ConversationID = conversation.ID
 
-		err = store.SaveMessage(&assistantReply)
-		if err != nil {
-			Fatal("Could not save assistant reply: %v\n", err)
+			err = store.SaveMessage(&reply)
+			if err != nil {
+				Warn("Could not save reply: %v\n", err)
+			}
 		}
 	},
 	ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
@@ -379,22 +382,24 @@ var newCmd = &cobra.Command{
 		}
 
 		RenderConversation(messages, true)
 
-		reply := Message{
-			ConversationID: conversation.ID,
-			Role:           "assistant",
-		}
-		reply.RenderTTY()
 
-		response, err := LLMRequest(messages)
+		// output a message heading
+		(&Message{
+			Role: MessageRoleAssistant,
+		}).RenderTTY()
+
+		replies, err := LLMRequest(messages)
 		if err != nil {
 			Fatal("Error fetching LLM response: %v\n", err)
 		}
 
-		reply.OriginalContent = response
+		for _, reply := range replies {
+			reply.ConversationID = conversation.ID
 
-		err = store.SaveMessage(&reply)
-		if err != nil {
-			Fatal("Could not save reply: %v\n", err)
+			err = store.SaveMessage(&reply)
+			if err != nil {
+				Warn("Could not save reply: %v\n", err)
+			}
 		}
 
 		err = conversation.GenerateTitle()
@@ -461,8 +466,9 @@ var retryCmd = &cobra.Command{
 		}
 
 		var lastUserMessageIndex int
+		// walk backwards through the messages to find the last user message
 		for i := len(messages) - 1; i >= 0; i-- {
-			if messages[i].Role == "user" {
+			if messages[i].Role == MessageRoleUser {
 				lastUserMessageIndex = i
 				break
 			}
@@ -471,22 +477,22 @@ var retryCmd = &cobra.Command{
 		messages = messages[:lastUserMessageIndex+1]
 
 		RenderConversation(messages, true)
 
-		assistantReply := Message{
-			ConversationID: conversation.ID,
-			Role:           "assistant",
-		}
-		assistantReply.RenderTTY()
+		(&Message{
+			Role: MessageRoleAssistant,
+		}).RenderTTY()
 
-		response, err := LLMRequest(messages)
+		replies, err := LLMRequest(messages)
 		if err != nil {
 			Fatal("Error fetching LLM response: %v\n", err)
 		}
 
-		assistantReply.OriginalContent = response
+		for _, reply := range replies {
+			reply.ConversationID = conversation.ID
 
-		err = store.SaveMessage(&assistantReply)
-		if err != nil {
-			Fatal("Could not save assistant reply: %v\n", err)
+			err = store.SaveMessage(&reply)
+			if err != nil {
+				Warn("Could not save reply: %v\n", err)
+			}
 		}
 	},
 	ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
@@ -522,22 +528,22 @@ var continueCmd = &cobra.Command{
 		}
 
 		RenderConversation(messages, true)
 
-		assistantReply := Message{
-			ConversationID: conversation.ID,
-			Role:           "assistant",
-		}
-		assistantReply.RenderTTY()
+		(&Message{
+			Role: MessageRoleAssistant,
+		}).RenderTTY()
 
-		response, err := LLMRequest(messages)
+		replies, err := LLMRequest(messages)
 		if err != nil {
 			Fatal("Error fetching LLM response: %v\n", err)
 		}
 
-		assistantReply.OriginalContent = response
+		for _, reply := range replies {
+			reply.ConversationID = conversation.ID
 
-		err = store.SaveMessage(&assistantReply)
-		if err != nil {
-			Fatal("Could not save assistant reply: %v\n", err)
+			err = store.SaveMessage(&reply)
+			if err != nil {
+				Warn("Could not save reply: %v\n", err)
+			}
 		}
 	},
 	ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
diff --git a/pkg/cli/conversation.go b/pkg/cli/conversation.go
index 7b9e786..197e7b8 100644
--- a/pkg/cli/conversation.go
+++ b/pkg/cli/conversation.go
@@ -5,18 +5,26 @@ import (
 	"strings"
 )
 
+type MessageRole string
+
+const (
+	MessageRoleUser      MessageRole = "user"
+	MessageRoleAssistant MessageRole = "assistant"
+	MessageRoleSystem    MessageRole = "system"
+)
+
 // FriendlyRole returns a human friendly signifier for the message's role.
 func (m *Message) FriendlyRole() string {
 	var friendlyRole string
 	switch m.Role {
-	case "user":
+	case MessageRoleUser:
 		friendlyRole = "You"
-	case "system":
+	case MessageRoleSystem:
 		friendlyRole = "System"
-	case "assistant":
+	case MessageRoleAssistant:
 		friendlyRole = "Assistant"
 	default:
-		friendlyRole = m.Role
+		friendlyRole = string(m.Role)
 	}
 	return friendlyRole
 }
@@ -27,13 +35,13 @@ func (c *Conversation) GenerateTitle() error {
 
 	messages := []Message{
 		{
-			Role:            "user",
+			Role:            MessageRoleUser,
 			OriginalContent: prompt,
 		},
 	}
 
 	model := "gpt-3.5-turbo" // use cheap model to generate title
-	response, err := CreateChatCompletion(model, messages, 25)
+	response, err := CreateChatCompletion(model, messages, 25, nil)
 	if err != nil {
 		return err
 	}
diff --git a/pkg/cli/openai.go b/pkg/cli/openai.go
index c03cf30..c4f9843 100644
--- a/pkg/cli/openai.go
+++ b/pkg/cli/openai.go
@@ -16,7 +16,7 @@ func CreateChatCompletionRequest(model string, messages []Message, maxTokens int
 	chatCompletionMessages := []openai.ChatCompletionMessage{}
 	for _, m := range messages {
 		message := openai.ChatCompletionMessage{
-			Role:    m.Role,
+			Role:    string(m.Role),
 			Content: m.OriginalContent,
 		}
 		if m.ToolCallID.Valid {
@@ -53,7 +53,7 @@ func CreateChatCompletionRequest(model string, messages []Message, maxTokens int
 // CreateChatCompletion submits a Chat Completion API request and returns the
 // response. CreateChatCompletion will recursively call itself in the case of
 // tool calls, until a response is received with the final user-facing output.
-func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) {
+func CreateChatCompletion(model string, messages []Message, maxTokens int, replies *[]Message) (string, error) {
 	client := openai.NewClient(*config.OpenAI.APIKey)
 	req := CreateChatCompletionRequest(model, messages, maxTokens)
 	resp, err := client.CreateChatCompletion(context.Background(), req)
@@ -64,25 +64,33 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri
 	choice := resp.Choices[0]
 
 	if len(choice.Message.ToolCalls) > 0 {
-		if choice.Message.Content != "" {
-			return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.")
Unsupported.") - } - // Append the assistant's reply with its request for tool calls toolCallJson, _ := json.Marshal(choice.Message.ToolCalls) - messages = append(messages, Message{ - Role: "assistant", + assistantReply := Message{ + Role: MessageRoleAssistant, ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true}, - }) + } toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls) if err != nil { return "", err } + if replies != nil { + *replies = append(append(*replies, assistantReply), toolReplies...) + } + + messages = append(append(messages, assistantReply), toolReplies...) // Recurse into CreateChatCompletion with the tool call replies added // to the original messages - return CreateChatCompletion(model, append(messages, toolReplies...), maxTokens) + return CreateChatCompletion(model, append(messages, toolReplies...), maxTokens, replies) + } + + if replies != nil { + *replies = append(*replies, Message{ + Role: MessageRoleAssistant, + OriginalContent: choice.Message.Content, + }) } // Return the user-facing message. @@ -92,7 +100,7 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri // CreateChatCompletionStream submits a streaming Chat Completion API request // and both returns and streams the response to the provided output channel. // May return a partial response if an error occurs mid-stream. -func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string) (string, error) { +func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string, replies *[]Message) (string, error) { client := openai.NewClient(*config.OpenAI.APIKey) req := CreateChatCompletionRequest(model, messages, maxTokens) @@ -137,25 +145,35 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int, } if len(toolCalls) > 0 { - if content.String() != "" { - return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.") - } - // Append the assistant's reply with its request for tool calls toolCallJson, _ := json.Marshal(toolCalls) - messages = append(messages, Message{ - Role: "assistant", - ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true}, - }) + + assistantReply := Message{ + Role: MessageRoleAssistant, + OriginalContent: content.String(), + ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true}, + } toolReplies, err := ExecuteToolCalls(toolCalls) if err != nil { return "", err } + if replies != nil { + *replies = append(append(*replies, assistantReply), toolReplies...) + } + // Recurse into CreateChatCompletionStream with the tool call replies // added to the original messages - return CreateChatCompletionStream(model, append(messages, toolReplies...), maxTokens, output) + messages = append(append(messages, assistantReply), toolReplies...) 
+		return CreateChatCompletionStream(model, messages, maxTokens, output, replies)
+	}
+
+	if replies != nil {
+		*replies = append(*replies, Message{
+			Role:            MessageRoleAssistant,
+			OriginalContent: content.String(),
+		})
 	}
 
 	return content.String(), err
diff --git a/pkg/cli/store.go b/pkg/cli/store.go
index 2ed7909..41f9795 100644
--- a/pkg/cli/store.go
+++ b/pkg/cli/store.go
@@ -23,7 +23,7 @@ type Message struct {
 	ConversationID  uint `gorm:"foreignKey:ConversationID"`
 	Conversation    Conversation
 	OriginalContent string
-	Role            string // one of: 'user', 'assistant', 'tool'
+	Role            MessageRole // one of: 'system', 'user', 'assistant', 'tool'
 	CreatedAt       time.Time
 	ToolCallID      sql.NullString
 	ToolCalls       sql.NullString // a json-encoded array of tool calls from the model
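
Aside (illustration only, not part of the patch): after one tool-call round
trip, the replies slice is expected to hold, in order, the assistant message
requesting the tool calls, one tool reply per call from ExecuteToolCalls,
then the final user-facing assistant message. A caller persisting the whole
exchange condenses the cmd.go handlers above; identifiers (model, messages,
maxTokens, receiver, conversation, store) are the ones used in the patch:

	var replies []Message
	_, err := CreateChatCompletionStream(model, messages, maxTokens, receiver, &replies)
	if err == nil {
		for _, reply := range replies {
			// e.g. [assistant(tool_calls), tool, assistant(content)]:
			// every message of the exchange is saved, not just the final text
			reply.ConversationID = conversation.ID
			if err := store.SaveMessage(&reply); err != nil {
				Warn("Could not save reply: %v\n", err)
			}
		}
	}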