Add *[]Message parameter to CreateChatCompletion methods

Allows replies (tool calls, user-facing messages) to be added in sequence as CreateChatCompletion* recurses into itself. Cleaned up cmd.go: no longer need to create a Message from the string content response.
commit ed6ee9bea9 (parent e850c340b7)
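
The shape of the change, distilled: callers hand CreateChatCompletion* a pointer to a slice, and each level of recursion appends the messages it produced before recursing. A minimal runnable sketch of that accumulation pattern — the Message and MessageRole names mirror the diff, while completeOnce is a hypothetical stand-in for a single API round trip, not part of the repo:

package main

import "fmt"

type MessageRole string

const (
	MessageRoleUser      MessageRole = "user"
	MessageRoleAssistant MessageRole = "assistant"
)

type Message struct {
	Role            MessageRole
	OriginalContent string
}

// completeOnce is a hypothetical stand-in for one API round trip. It
// returns the model's reply and whether that reply is a tool-call
// request (which would trigger another round trip).
func completeOnce(messages []Message) (Message, bool) {
	return Message{Role: MessageRoleAssistant, OriginalContent: "hello!"}, false
}

// createChatCompletion mirrors the diff's pattern: each recursion level
// appends the messages it produced to *replies before recursing with
// them added to the conversation so far.
func createChatCompletion(messages []Message, replies *[]Message) string {
	reply, isToolCall := completeOnce(messages)
	if replies != nil {
		*replies = append(*replies, reply)
	}
	if isToolCall {
		// The real code also executes the tools and appends their results here.
		return createChatCompletion(append(messages, reply), replies)
	}
	return reply.OriginalContent
}

func main() {
	var replies []Message
	content := createChatCompletion([]Message{{Role: MessageRoleUser, OriginalContent: "hi"}}, &replies)
	fmt.Println(content)
	// Each collected reply can now be saved in sequence, as cmd.go does.
	for _, r := range replies {
		fmt.Printf("%s: %s\n", r.Role, r.OriginalContent)
	}
}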
@@ -53,9 +53,10 @@ func SystemPrompt() string {
 	return systemPrompt
 }
 
-// LLMRequest prompts the LLM with the given Message, writes the (partial)
-// response to stdout, and returns the (partial) response or any errors.
-func LLMRequest(messages []Message) (string, error) {
+// LLMRequest prompts the LLM with the given messages, writing the response
+// to stdout. Returns all reply messages added by the LLM, including any
+// function call messages.
+func LLMRequest(messages []Message) ([]Message, error) {
 	// receiver receives the response from LLM
 	receiver := make(chan string)
 	defer close(receiver)
@@ -63,7 +64,8 @@ func LLMRequest(messages []Message) (string, error) {
 	// start HandleDelayedContent goroutine to print received data to stdout
 	go HandleDelayedContent(receiver)
 
-	response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
+	var replies []Message
+	response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver, &replies)
 	if response != "" {
 		if err != nil {
 			Warn("Received partial response. Error: %v\n", err)
@@ -73,7 +75,7 @@ func LLMRequest(messages []Message) (string, error) {
 		fmt.Println()
 	}
 
-	return response, err
+	return replies, err
 }
 
 // InputFromArgsOrEditor returns either the provided input from the args slice
@@ -316,22 +318,23 @@ var replyCmd = &cobra.Command{
 			messages = append(messages, userReply)
 
 			RenderConversation(messages, true)
-			assistantReply := Message{
-				ConversationID: conversation.ID,
-				Role:           "assistant",
-			}
-			assistantReply.RenderTTY()
+			// output an <Assistant> message heading
+			(&Message{
+				Role: MessageRoleAssistant,
+			}).RenderTTY()
 
-			response, err := LLMRequest(messages)
+			replies, err := LLMRequest(messages)
 			if err != nil {
 				Fatal("Error fetching LLM response: %v\n", err)
 			}
 
-			assistantReply.OriginalContent = response
-
-			err = store.SaveMessage(&assistantReply)
-			if err != nil {
-				Fatal("Could not save assistant reply: %v\n", err)
+			for _, reply := range replies {
+				reply.ConversationID = conversation.ID
+
+				err = store.SaveMessage(&reply)
+				if err != nil {
+					Warn("Could not save reply: %v\n", err)
+				}
 			}
 		},
 		ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
@@ -379,22 +382,24 @@ var newCmd = &cobra.Command{
 			}
 
 			RenderConversation(messages, true)
-			reply := Message{
-				ConversationID: conversation.ID,
-				Role:           "assistant",
-			}
-			reply.RenderTTY()
 
-			response, err := LLMRequest(messages)
+			// output an <Assistant> message heading
+			(&Message{
+				Role: MessageRoleAssistant,
+			}).RenderTTY()
+
+			replies, err := LLMRequest(messages)
 			if err != nil {
 				Fatal("Error fetching LLM response: %v\n", err)
 			}
 
-			reply.OriginalContent = response
-
-			err = store.SaveMessage(&reply)
-			if err != nil {
-				Fatal("Could not save reply: %v\n", err)
+			for _, reply := range replies {
+				reply.ConversationID = conversation.ID
+
+				err = store.SaveMessage(&reply)
+				if err != nil {
+					Warn("Could not save reply: %v\n", err)
+				}
 			}
 
 			err = conversation.GenerateTitle()
@@ -461,8 +466,9 @@ var retryCmd = &cobra.Command{
 			}
 
 			var lastUserMessageIndex int
+			// walk backwards through conversations to find last user message
 			for i := len(messages) - 1; i >= 0; i-- {
-				if messages[i].Role == "user" {
+				if messages[i].Role == MessageRoleUser {
 					lastUserMessageIndex = i
 					break
 				}
@@ -471,22 +477,22 @@ var retryCmd = &cobra.Command{
 			messages = messages[:lastUserMessageIndex+1]
 
 			RenderConversation(messages, true)
-			assistantReply := Message{
-				ConversationID: conversation.ID,
-				Role:           "assistant",
-			}
-			assistantReply.RenderTTY()
+			(&Message{
+				Role: MessageRoleAssistant,
+			}).RenderTTY()
 
-			response, err := LLMRequest(messages)
+			replies, err := LLMRequest(messages)
 			if err != nil {
 				Fatal("Error fetching LLM response: %v\n", err)
 			}
 
-			assistantReply.OriginalContent = response
-
-			err = store.SaveMessage(&assistantReply)
-			if err != nil {
-				Fatal("Could not save assistant reply: %v\n", err)
+			for _, reply := range replies {
+				reply.ConversationID = conversation.ID
+
+				err = store.SaveMessage(&reply)
+				if err != nil {
+					Warn("Could not save reply: %v\n", err)
+				}
 			}
 		},
 		ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
@@ -522,22 +528,22 @@ var continueCmd = &cobra.Command{
 			}
 
 			RenderConversation(messages, true)
-			assistantReply := Message{
-				ConversationID: conversation.ID,
-				Role:           "assistant",
-			}
-			assistantReply.RenderTTY()
+			(&Message{
+				Role: MessageRoleAssistant,
+			}).RenderTTY()
 
-			response, err := LLMRequest(messages)
+			replies, err := LLMRequest(messages)
 			if err != nil {
 				Fatal("Error fetching LLM response: %v\n", err)
 			}
 
-			assistantReply.OriginalContent = response
-
-			err = store.SaveMessage(&assistantReply)
-			if err != nil {
-				Fatal("Could not save assistant reply: %v\n", err)
+			for _, reply := range replies {
+				reply.ConversationID = conversation.ID
+
+				err = store.SaveMessage(&reply)
+				if err != nil {
+					Warn("Could not save reply: %v\n", err)
+				}
 			}
 		},
 		ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
@@ -5,18 +5,26 @@ import (
 	"strings"
 )
 
+type MessageRole string
+
+const (
+	MessageRoleUser      MessageRole = "user"
+	MessageRoleAssistant MessageRole = "assistant"
+	MessageRoleSystem    MessageRole = "system"
+)
+
 // FriendlyRole returns a human friendly signifier for the message's role.
 func (m *Message) FriendlyRole() string {
 	var friendlyRole string
 	switch m.Role {
-	case "user":
+	case MessageRoleUser:
 		friendlyRole = "You"
-	case "system":
+	case MessageRoleSystem:
 		friendlyRole = "System"
-	case "assistant":
+	case MessageRoleAssistant:
 		friendlyRole = "Assistant"
 	default:
-		friendlyRole = m.Role
+		friendlyRole = string(m.Role)
 	}
 	return friendlyRole
 }
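
Because MessageRole is a defined type whose underlying type is string, plain-string contexts now need an explicit conversion — hence string(m.Role) in the default branch above, and again in CreateChatCompletionRequest below. A standalone example of the semantics (this snippet is illustrative, not repo code):

package main

import "fmt"

type MessageRole string

const MessageRoleUser MessageRole = "user"

func main() {
	role := MessageRoleUser
	// Comparing against an untyped string constant compiles fine...
	fmt.Println(role == "user") // true
	// ...but assigning to a declared string variable requires a conversion.
	var s string = string(role)
	fmt.Println(s)
}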
@@ -27,13 +35,13 @@ func (c *Conversation) GenerateTitle() error {
 
 	messages := []Message{
 		{
-			Role:            "user",
+			Role:            MessageRoleUser,
 			OriginalContent: prompt,
 		},
 	}
 
 	model := "gpt-3.5-turbo" // use cheap model to generate title
-	response, err := CreateChatCompletion(model, messages, 25)
+	response, err := CreateChatCompletion(model, messages, 25, nil)
 	if err != nil {
 		return err
 	}
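
Passing nil for the new replies parameter opts GenerateTitle out of reply collection, since only the returned title string is needed. Call sites that do want the full reply trail pass a pointer instead, as the cmd.go hunks above do; a sketch of that shape, using only names from this diff:

	var replies []Message
	response, err := CreateChatCompletion(model, messages, maxTokens, &replies)
	// response holds the final user-facing content; replies now also holds
	// any intermediate tool-call request/result messages, in order.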
@@ -16,7 +16,7 @@ func CreateChatCompletionRequest(model string, messages []Message, maxTokens int
 	chatCompletionMessages := []openai.ChatCompletionMessage{}
 	for _, m := range messages {
 		message := openai.ChatCompletionMessage{
-			Role:    m.Role,
+			Role:    string(m.Role),
 			Content: m.OriginalContent,
 		}
 		if m.ToolCallID.Valid {
@@ -53,7 +53,7 @@ func CreateChatCompletionRequest(model string, messages []Message, maxTokens int
 // CreateChatCompletion submits a Chat Completion API request and returns the
 // response. CreateChatCompletion will recursively call itself in the case of
 // tool calls, until a response is received with the final user-facing output.
-func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) {
+func CreateChatCompletion(model string, messages []Message, maxTokens int, replies *[]Message) (string, error) {
 	client := openai.NewClient(*config.OpenAI.APIKey)
 	req := CreateChatCompletionRequest(model, messages, maxTokens)
 	resp, err := client.CreateChatCompletion(context.Background(), req)
@@ -64,25 +64,33 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri
 	choice := resp.Choices[0]
 
 	if len(choice.Message.ToolCalls) > 0 {
-		if choice.Message.Content != "" {
-			return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.")
-		}
-
 		// Append the assistant's reply with its request for tool calls
 		toolCallJson, _ := json.Marshal(choice.Message.ToolCalls)
-		messages = append(messages, Message{
-			Role:      "assistant",
+		assistantReply := Message{
+			Role:      MessageRoleAssistant,
 			ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
-		})
+		}
 
 		toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls)
 		if err != nil {
 			return "", err
 		}
 
+		if replies != nil {
+			*replies = append(append(*replies, assistantReply), toolReplies...)
+		}
+
+		messages = append(append(messages, assistantReply), toolReplies...)
 		// Recurse into CreateChatCompletion with the tool call replies added
 		// to the original messages
-		return CreateChatCompletion(model, append(messages, toolReplies...), maxTokens)
+		return CreateChatCompletion(model, messages, maxTokens, replies)
+	}
+
+	if replies != nil {
+		*replies = append(*replies, Message{
+			Role:            MessageRoleAssistant,
+			OriginalContent: choice.Message.Content,
+		})
 	}
 
 	// Return the user-facing message.
@@ -92,7 +100,7 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri
 // CreateChatCompletionStream submits a streaming Chat Completion API request
 // and both returns and streams the response to the provided output channel.
 // May return a partial response if an error occurs mid-stream.
-func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string) (string, error) {
+func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string, replies *[]Message) (string, error) {
 	client := openai.NewClient(*config.OpenAI.APIKey)
 	req := CreateChatCompletionRequest(model, messages, maxTokens)
 
@@ -137,25 +145,35 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
 	}
 
 	if len(toolCalls) > 0 {
-		if content.String() != "" {
-			return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.")
-		}
-
 		// Append the assistant's reply with its request for tool calls
 		toolCallJson, _ := json.Marshal(toolCalls)
-		messages = append(messages, Message{
-			Role:      "assistant",
-			ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
-		})
+
+		assistantReply := Message{
+			Role:            MessageRoleAssistant,
+			OriginalContent: content.String(),
+			ToolCalls:       sql.NullString{String: string(toolCallJson), Valid: true},
+		}
 
 		toolReplies, err := ExecuteToolCalls(toolCalls)
 		if err != nil {
 			return "", err
 		}
 
+		if replies != nil {
+			*replies = append(append(*replies, assistantReply), toolReplies...)
+		}
+
 		// Recurse into CreateChatCompletionStream with the tool call replies
 		// added to the original messages
-		return CreateChatCompletionStream(model, append(messages, toolReplies...), maxTokens, output)
+		messages = append(append(messages, assistantReply), toolReplies...)
+		return CreateChatCompletionStream(model, messages, maxTokens, output, replies)
+	}
+
+	if replies != nil {
+		*replies = append(*replies, Message{
+			Role:            MessageRoleAssistant,
+			OriginalContent: content.String(),
+		})
 	}
 
 	return content.String(), err
@@ -23,7 +23,7 @@ type Message struct {
 	ConversationID  uint `gorm:"foreignKey:ConversationID"`
 	Conversation    Conversation
 	OriginalContent string
-	Role            string // one of: 'user', 'assistant', 'tool'
+	Role            MessageRole // one of: 'system', 'user', 'assistant', 'tool'
 	CreatedAt       time.Time
 	ToolCallID      sql.NullString
 	ToolCalls       sql.NullString // a json-encoded array of tool calls from the model