Add *Message[] parameter to CreateChatCompletion methods

Allows replies (tool calls, user-facing messages) to be added in sequence
as CreateChatCompletion* recurses into itself.

Cleaned up cmd.go: no longer need to create a Message based on the
string content response.
This commit is contained in:
Matt Low 2023-11-29 04:43:53 +00:00
parent e850c340b7
commit ed6ee9bea9
4 changed files with 105 additions and 73 deletions

View File

@ -53,9 +53,10 @@ func SystemPrompt() string {
return systemPrompt return systemPrompt
} }
// LLMRequest prompts the LLM with the given Message, writes the (partial) // LLMRequest prompts the LLM with the given messages, writing the response
// response to stdout, and returns the (partial) response or any errors. // to stdout. Returns all reply messages added by the LLM, including any
func LLMRequest(messages []Message) (string, error) { // function call messages.
func LLMRequest(messages []Message) ([]Message, error) {
// receiver receives the response from LLM // receiver receives the response from LLM
receiver := make(chan string) receiver := make(chan string)
defer close(receiver) defer close(receiver)
@ -63,7 +64,8 @@ func LLMRequest(messages []Message) (string, error) {
// start HandleDelayedContent goroutine to print received data to stdout // start HandleDelayedContent goroutine to print received data to stdout
go HandleDelayedContent(receiver) go HandleDelayedContent(receiver)
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver) var replies []Message
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver, &replies)
if response != "" { if response != "" {
if err != nil { if err != nil {
Warn("Received partial response. Error: %v\n", err) Warn("Received partial response. Error: %v\n", err)
@ -73,7 +75,7 @@ func LLMRequest(messages []Message) (string, error) {
fmt.Println() fmt.Println()
} }
return response, err return replies, err
} }
// InputFromArgsOrEditor returns either the provided input from the args slice // InputFromArgsOrEditor returns either the provided input from the args slice
@ -316,22 +318,23 @@ var replyCmd = &cobra.Command{
messages = append(messages, userReply) messages = append(messages, userReply)
RenderConversation(messages, true) RenderConversation(messages, true)
assistantReply := Message{ // output an <Assistant> message heading
ConversationID: conversation.ID, (&Message{
Role: "assistant", Role: MessageRoleAssistant,
} }).RenderTTY()
assistantReply.RenderTTY()
response, err := LLMRequest(messages) replies, err := LLMRequest(messages)
if err != nil { if err != nil {
Fatal("Error fetching LLM response: %v\n", err) Fatal("Error fetching LLM response: %v\n", err)
} }
assistantReply.OriginalContent = response for _, reply := range replies {
reply.ConversationID = conversation.ID
err = store.SaveMessage(&assistantReply) err = store.SaveMessage(&reply)
if err != nil { if err != nil {
Fatal("Could not save assistant reply: %v\n", err) Warn("Could not save reply: %v\n", err)
}
} }
}, },
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
@ -379,22 +382,24 @@ var newCmd = &cobra.Command{
} }
RenderConversation(messages, true) RenderConversation(messages, true)
reply := Message{
ConversationID: conversation.ID,
Role: "assistant",
}
reply.RenderTTY()
response, err := LLMRequest(messages) // output an <Assistant> message heading
(&Message{
Role: MessageRoleAssistant,
}).RenderTTY()
replies, err := LLMRequest(messages)
if err != nil { if err != nil {
Fatal("Error fetching LLM response: %v\n", err) Fatal("Error fetching LLM response: %v\n", err)
} }
reply.OriginalContent = response for _, reply := range replies {
reply.ConversationID = conversation.ID
err = store.SaveMessage(&reply) err = store.SaveMessage(&reply)
if err != nil { if err != nil {
Fatal("Could not save reply: %v\n", err) Warn("Could not save reply: %v\n", err)
}
} }
err = conversation.GenerateTitle() err = conversation.GenerateTitle()
@ -461,8 +466,9 @@ var retryCmd = &cobra.Command{
} }
var lastUserMessageIndex int var lastUserMessageIndex int
// walk backwards through conversations to find last user message
for i := len(messages) - 1; i >= 0; i-- { for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == "user" { if messages[i].Role == MessageRoleUser {
lastUserMessageIndex = i lastUserMessageIndex = i
break break
} }
@ -471,22 +477,22 @@ var retryCmd = &cobra.Command{
messages = messages[:lastUserMessageIndex+1] messages = messages[:lastUserMessageIndex+1]
RenderConversation(messages, true) RenderConversation(messages, true)
assistantReply := Message{ (&Message{
ConversationID: conversation.ID, Role: MessageRoleAssistant,
Role: "assistant", }).RenderTTY()
}
assistantReply.RenderTTY()
response, err := LLMRequest(messages) replies, err := LLMRequest(messages)
if err != nil { if err != nil {
Fatal("Error fetching LLM response: %v\n", err) Fatal("Error fetching LLM response: %v\n", err)
} }
assistantReply.OriginalContent = response for _, reply := range replies {
reply.ConversationID = conversation.ID
err = store.SaveMessage(&assistantReply) err = store.SaveMessage(&reply)
if err != nil { if err != nil {
Fatal("Could not save assistant reply: %v\n", err) Warn("Could not save reply: %v\n", err)
}
} }
}, },
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
@ -522,22 +528,22 @@ var continueCmd = &cobra.Command{
} }
RenderConversation(messages, true) RenderConversation(messages, true)
assistantReply := Message{ (&Message{
ConversationID: conversation.ID, Role: MessageRoleAssistant,
Role: "assistant", }).RenderTTY()
}
assistantReply.RenderTTY()
response, err := LLMRequest(messages) replies, err := LLMRequest(messages)
if err != nil { if err != nil {
Fatal("Error fetching LLM response: %v\n", err) Fatal("Error fetching LLM response: %v\n", err)
} }
assistantReply.OriginalContent = response for _, reply := range replies {
reply.ConversationID = conversation.ID
err = store.SaveMessage(&assistantReply) err = store.SaveMessage(&reply)
if err != nil { if err != nil {
Fatal("Could not save assistant reply: %v\n", err) Warn("Could not save reply: %v\n", err)
}
} }
}, },
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {

View File

@ -5,18 +5,26 @@ import (
"strings" "strings"
) )
type MessageRole string
const (
MessageRoleUser MessageRole = "user"
MessageRoleAssistant MessageRole = "assistant"
MessageRoleSystem MessageRole = "system"
)
// FriendlyRole returns a human friendly signifier for the message's role. // FriendlyRole returns a human friendly signifier for the message's role.
func (m *Message) FriendlyRole() string { func (m *Message) FriendlyRole() string {
var friendlyRole string var friendlyRole string
switch m.Role { switch m.Role {
case "user": case MessageRoleUser:
friendlyRole = "You" friendlyRole = "You"
case "system": case MessageRoleSystem:
friendlyRole = "System" friendlyRole = "System"
case "assistant": case MessageRoleAssistant:
friendlyRole = "Assistant" friendlyRole = "Assistant"
default: default:
friendlyRole = m.Role friendlyRole = string(m.Role)
} }
return friendlyRole return friendlyRole
} }
@ -27,13 +35,13 @@ func (c *Conversation) GenerateTitle() error {
messages := []Message{ messages := []Message{
{ {
Role: "user", Role: MessageRoleUser,
OriginalContent: prompt, OriginalContent: prompt,
}, },
} }
model := "gpt-3.5-turbo" // use cheap model to generate title model := "gpt-3.5-turbo" // use cheap model to generate title
response, err := CreateChatCompletion(model, messages, 25) response, err := CreateChatCompletion(model, messages, 25, nil)
if err != nil { if err != nil {
return err return err
} }

View File

@ -16,7 +16,7 @@ func CreateChatCompletionRequest(model string, messages []Message, maxTokens int
chatCompletionMessages := []openai.ChatCompletionMessage{} chatCompletionMessages := []openai.ChatCompletionMessage{}
for _, m := range messages { for _, m := range messages {
message := openai.ChatCompletionMessage{ message := openai.ChatCompletionMessage{
Role: m.Role, Role: string(m.Role),
Content: m.OriginalContent, Content: m.OriginalContent,
} }
if m.ToolCallID.Valid { if m.ToolCallID.Valid {
@ -53,7 +53,7 @@ func CreateChatCompletionRequest(model string, messages []Message, maxTokens int
// CreateChatCompletion submits a Chat Completion API request and returns the // CreateChatCompletion submits a Chat Completion API request and returns the
// response. CreateChatCompletion will recursively call itself in the case of // response. CreateChatCompletion will recursively call itself in the case of
// tool calls, until a response is received with the final user-facing output. // tool calls, until a response is received with the final user-facing output.
func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) { func CreateChatCompletion(model string, messages []Message, maxTokens int, replies *[]Message) (string, error) {
client := openai.NewClient(*config.OpenAI.APIKey) client := openai.NewClient(*config.OpenAI.APIKey)
req := CreateChatCompletionRequest(model, messages, maxTokens) req := CreateChatCompletionRequest(model, messages, maxTokens)
resp, err := client.CreateChatCompletion(context.Background(), req) resp, err := client.CreateChatCompletion(context.Background(), req)
@ -64,25 +64,33 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri
choice := resp.Choices[0] choice := resp.Choices[0]
if len(choice.Message.ToolCalls) > 0 { if len(choice.Message.ToolCalls) > 0 {
if choice.Message.Content != "" {
return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.")
}
// Append the assistant's reply with its request for tool calls // Append the assistant's reply with its request for tool calls
toolCallJson, _ := json.Marshal(choice.Message.ToolCalls) toolCallJson, _ := json.Marshal(choice.Message.ToolCalls)
messages = append(messages, Message{ assistantReply := Message{
Role: "assistant", Role: MessageRoleAssistant,
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true}, ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
}) }
toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls) toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls)
if err != nil { if err != nil {
return "", err return "", err
} }
if replies != nil {
*replies = append(append(*replies, assistantReply), toolReplies...)
}
messages = append(append(messages, assistantReply), toolReplies...)
// Recurse into CreateChatCompletion with the tool call replies added // Recurse into CreateChatCompletion with the tool call replies added
// to the original messages // to the original messages
return CreateChatCompletion(model, append(messages, toolReplies...), maxTokens) return CreateChatCompletion(model, append(messages, toolReplies...), maxTokens, replies)
}
if replies != nil {
*replies = append(*replies, Message{
Role: MessageRoleAssistant,
OriginalContent: choice.Message.Content,
})
} }
// Return the user-facing message. // Return the user-facing message.
@ -92,7 +100,7 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri
// CreateChatCompletionStream submits a streaming Chat Completion API request // CreateChatCompletionStream submits a streaming Chat Completion API request
// and both returns and streams the response to the provided output channel. // and both returns and streams the response to the provided output channel.
// May return a partial response if an error occurs mid-stream. // May return a partial response if an error occurs mid-stream.
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string) (string, error) { func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string, replies *[]Message) (string, error) {
client := openai.NewClient(*config.OpenAI.APIKey) client := openai.NewClient(*config.OpenAI.APIKey)
req := CreateChatCompletionRequest(model, messages, maxTokens) req := CreateChatCompletionRequest(model, messages, maxTokens)
@ -137,25 +145,35 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
} }
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
if content.String() != "" {
return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.")
}
// Append the assistant's reply with its request for tool calls // Append the assistant's reply with its request for tool calls
toolCallJson, _ := json.Marshal(toolCalls) toolCallJson, _ := json.Marshal(toolCalls)
messages = append(messages, Message{
Role: "assistant", assistantReply := Message{
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true}, Role: MessageRoleAssistant,
}) OriginalContent: content.String(),
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
}
toolReplies, err := ExecuteToolCalls(toolCalls) toolReplies, err := ExecuteToolCalls(toolCalls)
if err != nil { if err != nil {
return "", err return "", err
} }
if replies != nil {
*replies = append(append(*replies, assistantReply), toolReplies...)
}
// Recurse into CreateChatCompletionStream with the tool call replies // Recurse into CreateChatCompletionStream with the tool call replies
// added to the original messages // added to the original messages
return CreateChatCompletionStream(model, append(messages, toolReplies...), maxTokens, output) messages = append(append(messages, assistantReply), toolReplies...)
return CreateChatCompletionStream(model, messages, maxTokens, output, replies)
}
if replies != nil {
*replies = append(*replies, Message{
Role: MessageRoleAssistant,
OriginalContent: content.String(),
})
} }
return content.String(), err return content.String(), err

View File

@ -23,7 +23,7 @@ type Message struct {
ConversationID uint `gorm:"foreignKey:ConversationID"` ConversationID uint `gorm:"foreignKey:ConversationID"`
Conversation Conversation Conversation Conversation
OriginalContent string OriginalContent string
Role string // one of: 'user', 'assistant', 'tool' Role MessageRole // one of: 'system', 'user', 'assistant', 'tool'
CreatedAt time.Time CreatedAt time.Time
ToolCallID sql.NullString ToolCallID sql.NullString
ToolCalls sql.NullString // a json-encoded array of tool calls from the model ToolCalls sql.NullString // a json-encoded array of tool calls from the model