Update CreateChatCompletion behavior

When the last message in the passed messages slice is an assistant
message, treat it as a partial message that is being continued, and
include its content in the newly created reply

Update TUI code to handle new behavior
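
The rule above, as a minimal self-contained Go sketch: when the conversation ends with an assistant message, the reply is seeded with that partial content before newly generated text is appended. Message, MessageRole, and completeReply here are simplified stand-ins for illustration, not the repository's actual types.

    package main

    import (
    	"fmt"
    	"strings"
    )

    type MessageRole string

    const (
    	MessageRoleUser      MessageRole = "user"
    	MessageRoleAssistant MessageRole = "assistant"
    )

    func (r MessageRole) IsAssistant() bool {
    	return r == MessageRoleAssistant
    }

    type Message struct {
    	Role    MessageRole
    	Content string
    }

    // completeReply seeds the reply with the trailing assistant message's
    // content before appending the provider's generated text, mirroring the
    // behavior this commit gives CreateChatCompletion.
    func completeReply(messages []Message, generated string) string {
    	sb := strings.Builder{}
    	if last := messages[len(messages)-1]; last.Role.IsAssistant() {
    		// continuation of a partial assistant reply
    		sb.WriteString(last.Content)
    	}
    	sb.WriteString(generated)
    	return sb.String()
    }

    func main() {
    	messages := []Message{
    		{Role: MessageRoleUser, Content: "Tell me a story."},
    		{Role: MessageRoleAssistant, Content: "Once upon a time"}, // partial
    	}
    	fmt.Println(completeReply(messages, ", in a land far away..."))
    	// prints: Once upon a time, in a land far away...
    }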
Matt Low 2024-03-22 17:51:01 +00:00
parent 3185b2d7d6
commit 91c74d9e1e
4 changed files with 127 additions and 69 deletions

View File

@@ -41,18 +41,28 @@ type RequestParameters struct {
 	ToolBag []Tool
 }
 
+func (m *MessageRole) IsAssistant() bool {
+	switch *m {
+	case MessageRoleAssistant, MessageRoleToolCall:
+		return true
+	}
+	return false
+}
+
 // FriendlyRole returns a human friendly signifier for the message's role.
 func (m *MessageRole) FriendlyRole() string {
-	var friendlyRole string
 	switch *m {
 	case MessageRoleUser:
-		friendlyRole = "You"
+		return "You"
 	case MessageRoleSystem:
-		friendlyRole = "System"
+		return "System"
 	case MessageRoleAssistant:
-		friendlyRole = "Assistant"
+		return "Assistant"
+	case MessageRoleToolCall:
+		return "Tool Call"
+	case MessageRoleToolResult:
+		return "Tool Result"
 	default:
-		friendlyRole = string(*m)
+		return string(*m)
 	}
-	return friendlyRole
 }

View File

@@ -42,10 +42,12 @@ type OriginalContent struct {
 }
 
 type Response struct {
 	Id           string            `json:"id"`
 	Type         string            `json:"type"`
 	Role         string            `json:"role"`
 	Content      []OriginalContent `json:"content"`
+	StopReason   string            `json:"stop_reason"`
+	StopSequence string            `json:"stop_sequence"`
 }
 
 const FUNCTION_STOP_SEQUENCE = "</function_calls>"
@@ -147,6 +149,10 @@ func (c *AnthropicClient) CreateChatCompletion(
 	messages []model.Message,
 	callback provider.ReplyCallback,
 ) (string, error) {
+	if len(messages) == 0 {
+		return "", fmt.Errorf("Can't create completion from no messages")
+	}
+
 	request := buildRequest(params, messages)
 
 	resp, err := sendRequest(ctx, c, request)
@@ -162,6 +168,14 @@ func (c *AnthropicClient) CreateChatCompletion(
 	}
 
 	sb := strings.Builder{}
+
+	lastMessage := messages[len(messages)-1]
+	if lastMessage.Role.IsAssistant() {
+		// this is a continuation of a previous assistant reply, so we'll
+		// include its contents in the final result
+		sb.WriteString(lastMessage.Content)
+	}
+
 	for _, content := range response.Content {
 		var reply model.Message
 		switch content.Type {
@@ -189,6 +203,10 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 	callback provider.ReplyCallback,
 	output chan<- string,
 ) (string, error) {
+	if len(messages) == 0 {
+		return "", fmt.Errorf("Can't create completion from no messages")
+	}
+
 	request := buildRequest(params, messages)
 	request.Stream = true
@@ -198,11 +216,18 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 	}
 	defer resp.Body.Close()
 
-	scanner := bufio.NewScanner(resp.Body)
 	sb := strings.Builder{}
 
-	isToolCall := false
+	lastMessage := messages[len(messages)-1]
+	continuation := false
+	if messages[len(messages)-1].Role.IsAssistant() {
+		// this is a continuation of a previous assistant reply, so we'll
+		// include its contents in the final result
+		sb.WriteString(lastMessage.Content)
+		continuation = true
+	}
 
+	scanner := bufio.NewScanner(resp.Body)
 	for scanner.Scan() {
 		line := scanner.Text()
 		line = strings.TrimSpace(line)
@@ -277,26 +302,21 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 						return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
 					}
 
-					isToolCall = true
-
-					funcCallXml := content[start:]
-					funcCallXml += FUNCTION_STOP_SEQUENCE
-
 					sb.WriteString(FUNCTION_STOP_SEQUENCE)
 					output <- FUNCTION_STOP_SEQUENCE
 
-					// Extract function calls
+					funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE
+
 					var functionCalls XMLFunctionCalls
-					err := xml.Unmarshal([]byte(sb.String()), &functionCalls)
+					err := xml.Unmarshal([]byte(funcCallXml), &functionCalls)
 					if err != nil {
 						return "", fmt.Errorf("failed to unmarshal function_calls: %v", err)
 					}
 
-					// Execute function calls
 					toolCall := model.Message{
 						Role: model.MessageRoleToolCall,
-						// xml stripped from content
-						Content: content[:start],
+						// function call xml stripped from content for model interop
+						Content:   strings.TrimSpace(content[:start]),
 						ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
 					}
@@ -305,33 +325,36 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 						return "", err
 					}
 
-					toolReply := model.Message{
+					toolResult := model.Message{
 						Role:        model.MessageRoleToolResult,
 						ToolResults: toolResults,
 					}
 
 					if callback != nil {
 						callback(toolCall)
-						callback(toolReply)
+						callback(toolResult)
 					}
 
-					// Recurse into CreateChatCompletionStream with the tool call replies
-					// added to the original messages
-					messages = append(append(messages, toolCall), toolReply)
+					if continuation {
+						messages[len(messages)-1] = toolCall
+					} else {
+						messages = append(messages, toolCall)
+					}
+					messages = append(messages, toolResult)
 
 					return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
 				}
 			}
 		case "message_stop":
 			// return the completed message
+			content := sb.String()
 			if callback != nil {
-				if !isToolCall {
-					callback(model.Message{
-						Role:    model.MessageRoleAssistant,
-						Content: sb.String(),
-					})
-				}
+				callback(model.Message{
+					Role:    model.MessageRoleAssistant,
+					Content: content,
+				})
 			}
-			return sb.String(), nil
+			return content, nil
 		case "error":
 			return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
 		default:

View File

@@ -137,7 +137,15 @@ func handleToolCalls(
 	params model.RequestParameters,
 	content string,
 	toolCalls []openai.ToolCall,
+	callback provider.ReplyCallback,
+	messages []model.Message,
 ) ([]model.Message, error) {
+	lastMessage := messages[len(messages)-1]
+	continuation := false
+	if lastMessage.Role.IsAssistant() {
+		continuation = true
+	}
+
 	toolCall := model.Message{
 		Role:    model.MessageRoleToolCall,
 		Content: content,
@@ -154,7 +162,19 @@ func handleToolCalls(
 		ToolResults: toolResults,
 	}
 
-	return []model.Message{toolCall, toolResult}, nil
+	if callback != nil {
+		callback(toolCall)
+		callback(toolResult)
+	}
+
+	if continuation {
+		messages[len(messages)-1] = toolCall
+	} else {
+		messages = append(messages, toolCall)
+	}
+	messages = append(messages, toolResult)
+
+	return messages, nil
 }
 
 func (c *OpenAIClient) CreateChatCompletion(
@@ -163,6 +183,10 @@ func (c *OpenAIClient) CreateChatCompletion(
 	messages []model.Message,
 	callback provider.ReplyCallback,
 ) (string, error) {
+	if len(messages) == 0 {
+		return "", fmt.Errorf("Can't create completion from no messages")
+	}
+
 	client := openai.NewClient(c.APIKey)
 	req := createChatCompletionRequest(c, params, messages)
 	resp, err := client.CreateChatCompletion(ctx, req)
@@ -172,32 +196,33 @@ func (c *OpenAIClient) CreateChatCompletion(
 	choice := resp.Choices[0]
 
+	var content string
+	lastMessage := messages[len(messages)-1]
+	if lastMessage.Role.IsAssistant() {
+		content = lastMessage.Content + choice.Message.Content
+	} else {
+		content = choice.Message.Content
+	}
+
 	toolCalls := choice.Message.ToolCalls
 	if len(toolCalls) > 0 {
-		results, err := handleToolCalls(params, choice.Message.Content, toolCalls)
+		messages, err := handleToolCalls(params, content, toolCalls, callback, messages)
 		if err != nil {
-			return "", err
+			return content, err
 		}
 
-		if callback != nil {
-			for _, result := range results {
-				callback(result)
-			}
-		}
-
-		// Recurse into CreateChatCompletion with the tool call replies
-		messages = append(messages, results...)
 		return c.CreateChatCompletion(ctx, params, messages, callback)
 	}
 
 	if callback != nil {
 		callback(model.Message{
 			Role:    model.MessageRoleAssistant,
-			Content: choice.Message.Content,
+			Content: content,
 		})
 	}
 
 	// Return the user-facing message.
-	return choice.Message.Content, nil
+	return content, nil
 }
 
 func (c *OpenAIClient) CreateChatCompletionStream(
@@ -207,6 +232,10 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 	callback provider.ReplyCallback,
 	output chan<- string,
 ) (string, error) {
+	if len(messages) == 0 {
+		return "", fmt.Errorf("Can't create completion from no messages")
+	}
+
 	client := openai.NewClient(c.APIKey)
 	req := createChatCompletionRequest(c, params, messages)
@@ -219,6 +248,11 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 	content := strings.Builder{}
 	toolCalls := []openai.ToolCall{}
 
+	lastMessage := messages[len(messages)-1]
+	if lastMessage.Role.IsAssistant() {
+		content.WriteString(lastMessage.Content)
+	}
+
 	// Iterate stream segments
 	for {
 		response, e := stream.Recv()
@@ -251,19 +285,12 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 	}
 
 	if len(toolCalls) > 0 {
-		results, err := handleToolCalls(params, content.String(), toolCalls)
+		messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages)
 		if err != nil {
 			return content.String(), err
 		}
 
-		if callback != nil {
-			for _, result := range results {
-				callback(result)
-			}
-		}
-
 		// Recurse into CreateChatCompletionStream with the tool call replies
-		messages = append(messages, results...)
 		return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
 	} else {
 		if callback != nil {

View File

@@ -152,6 +152,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 		case "ctrl+c":
 			if m.waitingForReply {
 				m.stopSignal <- ""
+				return m, nil
 			} else {
 				return m, tea.Quit
 			}
@@ -192,7 +193,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 	case msgResponseChunk:
 		chunk := string(msg)
 		last := len(m.messages) - 1
-		if last >= 0 && m.messages[last].Role == models.MessageRoleAssistant {
+		if last >= 0 && m.messages[last].Role.IsAssistant() {
 			m.setMessageContents(last, m.messages[last].Content+chunk)
 		} else {
 			m.addMessage(models.Message{
@@ -205,17 +206,16 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 	case msgAssistantReply:
 		// the last reply that was being worked on is finished
 		reply := models.Message(msg)
+		reply.Content = strings.TrimSpace(reply.Content)
+
 		last := len(m.messages) - 1
 		if last < 0 {
-			panic("Unexpected empty messages handling msgReply")
+			panic("Unexpected empty messages handling msgAssistantReply")
 		}
-		m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content))
 
-		if m.messages[last].Role == models.MessageRoleAssistant {
-			// the last message was an assistant message, so this is a continuation
-			if reply.Role == models.MessageRoleToolCall {
-				// update last message rrole to tool call
-				m.messages[last].Role = models.MessageRoleToolCall
-			}
+		if reply.Role.IsAssistant() && m.messages[last].Role.IsAssistant() {
+			// this was a continuation, so replace the previous message with the completed reply
+			m.setMessage(last, reply)
 		} else {
 			m.addMessage(reply)
 		}
@@ -496,9 +496,8 @@ func (m *model) handleMessagesKey(msg tea.KeyMsg) tea.Cmd {
 			scrollIntoView(&m.content, offset, 0.1)
 		}
 	case "ctrl+r":
-		// resubmit the conversation with all messages up until and including
-		// the selected message
-		if len(m.messages) == 0 {
+		// resubmit the conversation with all messages up until and including the selected message
+		if m.waitingForReply || len(m.messages) == 0 {
 			return nil
 		}
 		m.messages = m.messages[:m.selectedMessage+1]
@@ -543,13 +542,12 @@ func (m *model) handleInputKey(msg tea.KeyMsg) tea.Cmd {
 			return wrapError(err)
 		}
 
-		// ensure all messages up to the one we're about to add are
-		// persistent
+		// ensure all messages up to the one we're about to add are persisted
 		cmd := m.persistConversation()
 		if cmd != nil {
 			return cmd
 		}
 
-		// persist our new message, returning with any possible errors
 		savedReply, err := m.ctx.Store.AddReply(m.conversation, reply)
 		if err != nil {
 			return wrapError(err)