Add *Message[] parameter to CreateChatCompletion methods

Allows replies (tool calls, user-facing messges) to be added in sequence
as CreateChatCompleion* recurses into itself.

Cleaned up cmd.go: no longer need to create a Message based on the
string content response.
This commit is contained in:
Matt Low 2023-11-29 04:43:53 +00:00
parent e850c340b7
commit ed6ee9bea9
4 changed files with 105 additions and 73 deletions

View File

@ -53,9 +53,10 @@ func SystemPrompt() string {
return systemPrompt
}
// LLMRequest prompts the LLM with the given Message, writes the (partial)
// response to stdout, and returns the (partial) response or any errors.
func LLMRequest(messages []Message) (string, error) {
// LLMRequest prompts the LLM with the given messages, writing the response
// to stdout. Returns all reply messages added by the LLM, including any
// function call messages.
func LLMRequest(messages []Message) ([]Message, error) {
// receiver receives the reponse from LLM
receiver := make(chan string)
defer close(receiver)
@ -63,7 +64,8 @@ func LLMRequest(messages []Message) (string, error) {
// start HandleDelayedContent goroutine to print received data to stdout
go HandleDelayedContent(receiver)
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
var replies []Message
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver, &replies)
if response != "" {
if err != nil {
Warn("Received partial response. Error: %v\n", err)
@ -73,7 +75,7 @@ func LLMRequest(messages []Message) (string, error) {
fmt.Println()
}
return response, err
return replies, err
}
// InputFromArgsOrEditor returns either the provided input from the args slice
@ -316,22 +318,23 @@ var replyCmd = &cobra.Command{
messages = append(messages, userReply)
RenderConversation(messages, true)
assistantReply := Message{
ConversationID: conversation.ID,
Role: "assistant",
}
assistantReply.RenderTTY()
// output an <Assistant> message heading
(&Message{
Role: MessageRoleAssistant,
}).RenderTTY()
response, err := LLMRequest(messages)
replies, err := LLMRequest(messages)
if err != nil {
Fatal("Error fetching LLM response: %v\n", err)
}
assistantReply.OriginalContent = response
for _, reply := range replies {
reply.ConversationID = conversation.ID
err = store.SaveMessage(&assistantReply)
err = store.SaveMessage(&reply)
if err != nil {
Fatal("Could not save assistant reply: %v\n", err)
Warn("Could not save reply: %v\n", err)
}
}
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
@ -379,22 +382,24 @@ var newCmd = &cobra.Command{
}
RenderConversation(messages, true)
reply := Message{
ConversationID: conversation.ID,
Role: "assistant",
}
reply.RenderTTY()
response, err := LLMRequest(messages)
// output an <Assistant> message heading
(&Message{
Role: MessageRoleAssistant,
}).RenderTTY()
replies, err := LLMRequest(messages)
if err != nil {
Fatal("Error fetching LLM response: %v\n", err)
}
reply.OriginalContent = response
for _, reply := range replies {
reply.ConversationID = conversation.ID
err = store.SaveMessage(&reply)
if err != nil {
Fatal("Could not save reply: %v\n", err)
Warn("Could not save reply: %v\n", err)
}
}
err = conversation.GenerateTitle()
@ -461,8 +466,9 @@ var retryCmd = &cobra.Command{
}
var lastUserMessageIndex int
// walk backwards through conversations to find last user message
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == "user" {
if messages[i].Role == MessageRoleUser {
lastUserMessageIndex = i
break
}
@ -471,22 +477,22 @@ var retryCmd = &cobra.Command{
messages = messages[:lastUserMessageIndex+1]
RenderConversation(messages, true)
assistantReply := Message{
ConversationID: conversation.ID,
Role: "assistant",
}
assistantReply.RenderTTY()
(&Message{
Role: MessageRoleAssistant,
}).RenderTTY()
response, err := LLMRequest(messages)
replies, err := LLMRequest(messages)
if err != nil {
Fatal("Error fetching LLM response: %v\n", err)
}
assistantReply.OriginalContent = response
for _, reply := range replies {
reply.ConversationID = conversation.ID
err = store.SaveMessage(&assistantReply)
err = store.SaveMessage(&reply)
if err != nil {
Fatal("Could not save assistant reply: %v\n", err)
Warn("Could not save reply: %v\n", err)
}
}
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
@ -522,22 +528,22 @@ var continueCmd = &cobra.Command{
}
RenderConversation(messages, true)
assistantReply := Message{
ConversationID: conversation.ID,
Role: "assistant",
}
assistantReply.RenderTTY()
(&Message{
Role: MessageRoleAssistant,
}).RenderTTY()
response, err := LLMRequest(messages)
replies, err := LLMRequest(messages)
if err != nil {
Fatal("Error fetching LLM response: %v\n", err)
}
assistantReply.OriginalContent = response
for _, reply := range replies {
reply.ConversationID = conversation.ID
err = store.SaveMessage(&assistantReply)
err = store.SaveMessage(&reply)
if err != nil {
Fatal("Could not save assistant reply: %v\n", err)
Warn("Could not save reply: %v\n", err)
}
}
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {

View File

@ -5,18 +5,26 @@ import (
"strings"
)
type MessageRole string
const (
MessageRoleUser MessageRole = "user"
MessageRoleAssistant MessageRole = "assistant"
MessageRoleSystem MessageRole = "system"
)
// FriendlyRole returns a human friendly signifier for the message's role.
func (m *Message) FriendlyRole() string {
var friendlyRole string
switch m.Role {
case "user":
case MessageRoleUser:
friendlyRole = "You"
case "system":
case MessageRoleSystem:
friendlyRole = "System"
case "assistant":
case MessageRoleAssistant:
friendlyRole = "Assistant"
default:
friendlyRole = m.Role
friendlyRole = string(m.Role)
}
return friendlyRole
}
@ -27,13 +35,13 @@ func (c *Conversation) GenerateTitle() error {
messages := []Message{
{
Role: "user",
Role: MessageRoleUser,
OriginalContent: prompt,
},
}
model := "gpt-3.5-turbo" // use cheap model to generate title
response, err := CreateChatCompletion(model, messages, 25)
response, err := CreateChatCompletion(model, messages, 25, nil)
if err != nil {
return err
}

View File

@ -16,7 +16,7 @@ func CreateChatCompletionRequest(model string, messages []Message, maxTokens int
chatCompletionMessages := []openai.ChatCompletionMessage{}
for _, m := range messages {
message := openai.ChatCompletionMessage{
Role: m.Role,
Role: string(m.Role),
Content: m.OriginalContent,
}
if m.ToolCallID.Valid {
@ -53,7 +53,7 @@ func CreateChatCompletionRequest(model string, messages []Message, maxTokens int
// CreateChatCompletion submits a Chat Completion API request and returns the
// response. CreateChatCompletion will recursively call itself in the case of
// tool calls, until a response is received with the final user-facing output.
func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) {
func CreateChatCompletion(model string, messages []Message, maxTokens int, replies *[]Message) (string, error) {
client := openai.NewClient(*config.OpenAI.APIKey)
req := CreateChatCompletionRequest(model, messages, maxTokens)
resp, err := client.CreateChatCompletion(context.Background(), req)
@ -64,25 +64,33 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri
choice := resp.Choices[0]
if len(choice.Message.ToolCalls) > 0 {
if choice.Message.Content != "" {
return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.")
}
// Append the assistant's reply with its request for tool calls
toolCallJson, _ := json.Marshal(choice.Message.ToolCalls)
messages = append(messages, Message{
Role: "assistant",
assistantReply := Message{
Role: MessageRoleAssistant,
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
})
}
toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls)
if err != nil {
return "", err
}
if replies != nil {
*replies = append(append(*replies, assistantReply), toolReplies...)
}
messages = append(append(messages, assistantReply), toolReplies...)
// Recurse into CreateChatCompletion with the tool call replies added
// to the original messages
return CreateChatCompletion(model, append(messages, toolReplies...), maxTokens)
return CreateChatCompletion(model, append(messages, toolReplies...), maxTokens, replies)
}
if replies != nil {
*replies = append(*replies, Message{
Role: MessageRoleAssistant,
OriginalContent: choice.Message.Content,
})
}
// Return the user-facing message.
@ -92,7 +100,7 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri
// CreateChatCompletionStream submits a streaming Chat Completion API request
// and both returns and streams the response to the provided output channel.
// May return a partial response if an error occurs mid-stream.
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string) (string, error) {
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string, replies *[]Message) (string, error) {
client := openai.NewClient(*config.OpenAI.APIKey)
req := CreateChatCompletionRequest(model, messages, maxTokens)
@ -137,25 +145,35 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
}
if len(toolCalls) > 0 {
if content.String() != "" {
return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.")
}
// Append the assistant's reply with its request for tool calls
toolCallJson, _ := json.Marshal(toolCalls)
messages = append(messages, Message{
Role: "assistant",
assistantReply := Message{
Role: MessageRoleAssistant,
OriginalContent: content.String(),
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
})
}
toolReplies, err := ExecuteToolCalls(toolCalls)
if err != nil {
return "", err
}
if replies != nil {
*replies = append(append(*replies, assistantReply), toolReplies...)
}
// Recurse into CreateChatCompletionStream with the tool call replies
// added to the original messages
return CreateChatCompletionStream(model, append(messages, toolReplies...), maxTokens, output)
messages = append(append(messages, assistantReply), toolReplies...)
return CreateChatCompletionStream(model, messages, maxTokens, output, replies)
}
if replies != nil {
*replies = append(*replies, Message{
Role: MessageRoleAssistant,
OriginalContent: content.String(),
})
}
return content.String(), err

View File

@ -23,7 +23,7 @@ type Message struct {
ConversationID uint `gorm:"foreignKey:ConversationID"`
Conversation Conversation
OriginalContent string
Role string // one of: 'user', 'assistant', 'tool'
Role MessageRole // one of: 'system', 'user', 'assistant', 'tool'
CreatedAt time.Time
ToolCallID sql.NullString
ToolCalls sql.NullString // a json-encoded array of tool calls from the model