Add *Message[] parameter to CreateChatCompletion methods
Allows replies (tool calls, user-facing messges) to be added in sequence as CreateChatCompleion* recurses into itself. Cleaned up cmd.go: no longer need to create a Message based on the string content response.
This commit is contained in:
parent
e850c340b7
commit
ed6ee9bea9
@ -53,9 +53,10 @@ func SystemPrompt() string {
|
||||
return systemPrompt
|
||||
}
|
||||
|
||||
// LLMRequest prompts the LLM with the given Message, writes the (partial)
|
||||
// response to stdout, and returns the (partial) response or any errors.
|
||||
func LLMRequest(messages []Message) (string, error) {
|
||||
// LLMRequest prompts the LLM with the given messages, writing the response
|
||||
// to stdout. Returns all reply messages added by the LLM, including any
|
||||
// function call messages.
|
||||
func LLMRequest(messages []Message) ([]Message, error) {
|
||||
// receiver receives the reponse from LLM
|
||||
receiver := make(chan string)
|
||||
defer close(receiver)
|
||||
@ -63,7 +64,8 @@ func LLMRequest(messages []Message) (string, error) {
|
||||
// start HandleDelayedContent goroutine to print received data to stdout
|
||||
go HandleDelayedContent(receiver)
|
||||
|
||||
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
||||
var replies []Message
|
||||
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver, &replies)
|
||||
if response != "" {
|
||||
if err != nil {
|
||||
Warn("Received partial response. Error: %v\n", err)
|
||||
@ -73,7 +75,7 @@ func LLMRequest(messages []Message) (string, error) {
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
return response, err
|
||||
return replies, err
|
||||
}
|
||||
|
||||
// InputFromArgsOrEditor returns either the provided input from the args slice
|
||||
@ -316,22 +318,23 @@ var replyCmd = &cobra.Command{
|
||||
messages = append(messages, userReply)
|
||||
|
||||
RenderConversation(messages, true)
|
||||
assistantReply := Message{
|
||||
ConversationID: conversation.ID,
|
||||
Role: "assistant",
|
||||
}
|
||||
assistantReply.RenderTTY()
|
||||
// output an <Assistant> message heading
|
||||
(&Message{
|
||||
Role: MessageRoleAssistant,
|
||||
}).RenderTTY()
|
||||
|
||||
response, err := LLMRequest(messages)
|
||||
replies, err := LLMRequest(messages)
|
||||
if err != nil {
|
||||
Fatal("Error fetching LLM response: %v\n", err)
|
||||
}
|
||||
|
||||
assistantReply.OriginalContent = response
|
||||
for _, reply := range replies {
|
||||
reply.ConversationID = conversation.ID
|
||||
|
||||
err = store.SaveMessage(&assistantReply)
|
||||
if err != nil {
|
||||
Fatal("Could not save assistant reply: %v\n", err)
|
||||
err = store.SaveMessage(&reply)
|
||||
if err != nil {
|
||||
Warn("Could not save reply: %v\n", err)
|
||||
}
|
||||
}
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
@ -379,22 +382,24 @@ var newCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
RenderConversation(messages, true)
|
||||
reply := Message{
|
||||
ConversationID: conversation.ID,
|
||||
Role: "assistant",
|
||||
}
|
||||
reply.RenderTTY()
|
||||
|
||||
response, err := LLMRequest(messages)
|
||||
// output an <Assistant> message heading
|
||||
(&Message{
|
||||
Role: MessageRoleAssistant,
|
||||
}).RenderTTY()
|
||||
|
||||
replies, err := LLMRequest(messages)
|
||||
if err != nil {
|
||||
Fatal("Error fetching LLM response: %v\n", err)
|
||||
}
|
||||
|
||||
reply.OriginalContent = response
|
||||
for _, reply := range replies {
|
||||
reply.ConversationID = conversation.ID
|
||||
|
||||
err = store.SaveMessage(&reply)
|
||||
if err != nil {
|
||||
Fatal("Could not save reply: %v\n", err)
|
||||
err = store.SaveMessage(&reply)
|
||||
if err != nil {
|
||||
Warn("Could not save reply: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = conversation.GenerateTitle()
|
||||
@ -461,8 +466,9 @@ var retryCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
var lastUserMessageIndex int
|
||||
// walk backwards through conversations to find last user message
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == "user" {
|
||||
if messages[i].Role == MessageRoleUser {
|
||||
lastUserMessageIndex = i
|
||||
break
|
||||
}
|
||||
@ -471,22 +477,22 @@ var retryCmd = &cobra.Command{
|
||||
messages = messages[:lastUserMessageIndex+1]
|
||||
|
||||
RenderConversation(messages, true)
|
||||
assistantReply := Message{
|
||||
ConversationID: conversation.ID,
|
||||
Role: "assistant",
|
||||
}
|
||||
assistantReply.RenderTTY()
|
||||
(&Message{
|
||||
Role: MessageRoleAssistant,
|
||||
}).RenderTTY()
|
||||
|
||||
response, err := LLMRequest(messages)
|
||||
replies, err := LLMRequest(messages)
|
||||
if err != nil {
|
||||
Fatal("Error fetching LLM response: %v\n", err)
|
||||
}
|
||||
|
||||
assistantReply.OriginalContent = response
|
||||
for _, reply := range replies {
|
||||
reply.ConversationID = conversation.ID
|
||||
|
||||
err = store.SaveMessage(&assistantReply)
|
||||
if err != nil {
|
||||
Fatal("Could not save assistant reply: %v\n", err)
|
||||
err = store.SaveMessage(&reply)
|
||||
if err != nil {
|
||||
Warn("Could not save reply: %v\n", err)
|
||||
}
|
||||
}
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
@ -522,22 +528,22 @@ var continueCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
RenderConversation(messages, true)
|
||||
assistantReply := Message{
|
||||
ConversationID: conversation.ID,
|
||||
Role: "assistant",
|
||||
}
|
||||
assistantReply.RenderTTY()
|
||||
(&Message{
|
||||
Role: MessageRoleAssistant,
|
||||
}).RenderTTY()
|
||||
|
||||
response, err := LLMRequest(messages)
|
||||
replies, err := LLMRequest(messages)
|
||||
if err != nil {
|
||||
Fatal("Error fetching LLM response: %v\n", err)
|
||||
}
|
||||
|
||||
assistantReply.OriginalContent = response
|
||||
for _, reply := range replies {
|
||||
reply.ConversationID = conversation.ID
|
||||
|
||||
err = store.SaveMessage(&assistantReply)
|
||||
if err != nil {
|
||||
Fatal("Could not save assistant reply: %v\n", err)
|
||||
err = store.SaveMessage(&reply)
|
||||
if err != nil {
|
||||
Warn("Could not save reply: %v\n", err)
|
||||
}
|
||||
}
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
|
@ -5,18 +5,26 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type MessageRole string
|
||||
|
||||
const (
|
||||
MessageRoleUser MessageRole = "user"
|
||||
MessageRoleAssistant MessageRole = "assistant"
|
||||
MessageRoleSystem MessageRole = "system"
|
||||
)
|
||||
|
||||
// FriendlyRole returns a human friendly signifier for the message's role.
|
||||
func (m *Message) FriendlyRole() string {
|
||||
var friendlyRole string
|
||||
switch m.Role {
|
||||
case "user":
|
||||
case MessageRoleUser:
|
||||
friendlyRole = "You"
|
||||
case "system":
|
||||
case MessageRoleSystem:
|
||||
friendlyRole = "System"
|
||||
case "assistant":
|
||||
case MessageRoleAssistant:
|
||||
friendlyRole = "Assistant"
|
||||
default:
|
||||
friendlyRole = m.Role
|
||||
friendlyRole = string(m.Role)
|
||||
}
|
||||
return friendlyRole
|
||||
}
|
||||
@ -27,13 +35,13 @@ func (c *Conversation) GenerateTitle() error {
|
||||
|
||||
messages := []Message{
|
||||
{
|
||||
Role: "user",
|
||||
Role: MessageRoleUser,
|
||||
OriginalContent: prompt,
|
||||
},
|
||||
}
|
||||
|
||||
model := "gpt-3.5-turbo" // use cheap model to generate title
|
||||
response, err := CreateChatCompletion(model, messages, 25)
|
||||
response, err := CreateChatCompletion(model, messages, 25, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -16,7 +16,7 @@ func CreateChatCompletionRequest(model string, messages []Message, maxTokens int
|
||||
chatCompletionMessages := []openai.ChatCompletionMessage{}
|
||||
for _, m := range messages {
|
||||
message := openai.ChatCompletionMessage{
|
||||
Role: m.Role,
|
||||
Role: string(m.Role),
|
||||
Content: m.OriginalContent,
|
||||
}
|
||||
if m.ToolCallID.Valid {
|
||||
@ -53,7 +53,7 @@ func CreateChatCompletionRequest(model string, messages []Message, maxTokens int
|
||||
// CreateChatCompletion submits a Chat Completion API request and returns the
|
||||
// response. CreateChatCompletion will recursively call itself in the case of
|
||||
// tool calls, until a response is received with the final user-facing output.
|
||||
func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) {
|
||||
func CreateChatCompletion(model string, messages []Message, maxTokens int, replies *[]Message) (string, error) {
|
||||
client := openai.NewClient(*config.OpenAI.APIKey)
|
||||
req := CreateChatCompletionRequest(model, messages, maxTokens)
|
||||
resp, err := client.CreateChatCompletion(context.Background(), req)
|
||||
@ -64,25 +64,33 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri
|
||||
choice := resp.Choices[0]
|
||||
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
if choice.Message.Content != "" {
|
||||
return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.")
|
||||
}
|
||||
|
||||
// Append the assistant's reply with its request for tool calls
|
||||
toolCallJson, _ := json.Marshal(choice.Message.ToolCalls)
|
||||
messages = append(messages, Message{
|
||||
Role: "assistant",
|
||||
assistantReply := Message{
|
||||
Role: MessageRoleAssistant,
|
||||
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
|
||||
})
|
||||
}
|
||||
|
||||
toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if replies != nil {
|
||||
*replies = append(append(*replies, assistantReply), toolReplies...)
|
||||
}
|
||||
|
||||
messages = append(append(messages, assistantReply), toolReplies...)
|
||||
// Recurse into CreateChatCompletion with the tool call replies added
|
||||
// to the original messages
|
||||
return CreateChatCompletion(model, append(messages, toolReplies...), maxTokens)
|
||||
return CreateChatCompletion(model, append(messages, toolReplies...), maxTokens, replies)
|
||||
}
|
||||
|
||||
if replies != nil {
|
||||
*replies = append(*replies, Message{
|
||||
Role: MessageRoleAssistant,
|
||||
OriginalContent: choice.Message.Content,
|
||||
})
|
||||
}
|
||||
|
||||
// Return the user-facing message.
|
||||
@ -92,7 +100,7 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri
|
||||
// CreateChatCompletionStream submits a streaming Chat Completion API request
|
||||
// and both returns and streams the response to the provided output channel.
|
||||
// May return a partial response if an error occurs mid-stream.
|
||||
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string) (string, error) {
|
||||
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string, replies *[]Message) (string, error) {
|
||||
client := openai.NewClient(*config.OpenAI.APIKey)
|
||||
req := CreateChatCompletionRequest(model, messages, maxTokens)
|
||||
|
||||
@ -137,25 +145,35 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
|
||||
}
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
if content.String() != "" {
|
||||
return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.")
|
||||
}
|
||||
|
||||
// Append the assistant's reply with its request for tool calls
|
||||
toolCallJson, _ := json.Marshal(toolCalls)
|
||||
messages = append(messages, Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
|
||||
})
|
||||
|
||||
assistantReply := Message{
|
||||
Role: MessageRoleAssistant,
|
||||
OriginalContent: content.String(),
|
||||
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
|
||||
}
|
||||
|
||||
toolReplies, err := ExecuteToolCalls(toolCalls)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if replies != nil {
|
||||
*replies = append(append(*replies, assistantReply), toolReplies...)
|
||||
}
|
||||
|
||||
// Recurse into CreateChatCompletionStream with the tool call replies
|
||||
// added to the original messages
|
||||
return CreateChatCompletionStream(model, append(messages, toolReplies...), maxTokens, output)
|
||||
messages = append(append(messages, assistantReply), toolReplies...)
|
||||
return CreateChatCompletionStream(model, messages, maxTokens, output, replies)
|
||||
}
|
||||
|
||||
if replies != nil {
|
||||
*replies = append(*replies, Message{
|
||||
Role: MessageRoleAssistant,
|
||||
OriginalContent: content.String(),
|
||||
})
|
||||
}
|
||||
|
||||
return content.String(), err
|
||||
|
@ -23,7 +23,7 @@ type Message struct {
|
||||
ConversationID uint `gorm:"foreignKey:ConversationID"`
|
||||
Conversation Conversation
|
||||
OriginalContent string
|
||||
Role string // one of: 'user', 'assistant', 'tool'
|
||||
Role MessageRole // one of: 'system', 'user', 'assistant', 'tool'
|
||||
CreatedAt time.Time
|
||||
ToolCallID sql.NullString
|
||||
ToolCalls sql.NullString // a json-encoded array of tool calls from the model
|
||||
|
Loading…
Reference in New Issue
Block a user