Update CreateChatCompletion behavior

When the last message in the passed messages slice is an assistant
message, treat it as a partial message that is being continued, and
include its content in the newly created reply

Update TUI code to handle new behavior
This commit is contained in:
Matt Low 2024-03-22 17:51:01 +00:00
parent 3185b2d7d6
commit 91c74d9e1e
4 changed files with 127 additions and 69 deletions

View File

@ -41,18 +41,28 @@ type RequestParameters struct {
ToolBag []Tool
}
func (m *MessageRole) IsAssistant() bool {
switch *m {
case MessageRoleAssistant, MessageRoleToolCall:
return true
}
return false
}
// FriendlyRole returns a human friendly signifier for the message's role.
func (m *MessageRole) FriendlyRole() string {
var friendlyRole string
switch *m {
case MessageRoleUser:
friendlyRole = "You"
return "You"
case MessageRoleSystem:
friendlyRole = "System"
return "System"
case MessageRoleAssistant:
friendlyRole = "Assistant"
return "Assistant"
case MessageRoleToolCall:
return "Tool Call"
case MessageRoleToolResult:
return "Tool Result"
default:
friendlyRole = string(*m)
return string(*m)
}
return friendlyRole
}

View File

@ -42,10 +42,12 @@ type OriginalContent struct {
}
type Response struct {
Id string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Content []OriginalContent `json:"content"`
Id string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Content []OriginalContent `json:"content"`
StopReason string `json:"stop_reason"`
StopSequence string `json:"stop_sequence"`
}
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
@ -147,6 +149,10 @@ func (c *AnthropicClient) CreateChatCompletion(
messages []model.Message,
callback provider.ReplyCallback,
) (string, error) {
if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages")
}
request := buildRequest(params, messages)
resp, err := sendRequest(ctx, c, request)
@ -162,6 +168,14 @@ func (c *AnthropicClient) CreateChatCompletion(
}
sb := strings.Builder{}
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
// this is a continuation of a previous assistant reply, so we'll
// include its contents in the final result
sb.WriteString(lastMessage.Content)
}
for _, content := range response.Content {
var reply model.Message
switch content.Type {
@ -189,6 +203,10 @@ func (c *AnthropicClient) CreateChatCompletionStream(
callback provider.ReplyCallback,
output chan<- string,
) (string, error) {
if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages")
}
request := buildRequest(params, messages)
request.Stream = true
@ -198,11 +216,18 @@ func (c *AnthropicClient) CreateChatCompletionStream(
}
defer resp.Body.Close()
scanner := bufio.NewScanner(resp.Body)
sb := strings.Builder{}
isToolCall := false
lastMessage := messages[len(messages)-1]
continuation := false
if messages[len(messages)-1].Role.IsAssistant() {
// this is a continuation of a previous assistant reply, so we'll
// include its contents in the final result
sb.WriteString(lastMessage.Content)
continuation = true
}
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
line = strings.TrimSpace(line)
@ -277,26 +302,21 @@ func (c *AnthropicClient) CreateChatCompletionStream(
return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
}
isToolCall = true
funcCallXml := content[start:]
funcCallXml += FUNCTION_STOP_SEQUENCE
sb.WriteString(FUNCTION_STOP_SEQUENCE)
output <- FUNCTION_STOP_SEQUENCE
// Extract function calls
funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE
var functionCalls XMLFunctionCalls
err := xml.Unmarshal([]byte(sb.String()), &functionCalls)
err := xml.Unmarshal([]byte(funcCallXml), &functionCalls)
if err != nil {
return "", fmt.Errorf("failed to unmarshal function_calls: %v", err)
}
// Execute function calls
toolCall := model.Message{
Role: model.MessageRoleToolCall,
// xml stripped from content
Content: content[:start],
// function call xml stripped from content for model interop
Content: strings.TrimSpace(content[:start]),
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
}
@ -305,33 +325,36 @@ func (c *AnthropicClient) CreateChatCompletionStream(
return "", err
}
toolReply := model.Message{
toolResult := model.Message{
Role: model.MessageRoleToolResult,
ToolResults: toolResults,
}
if callback != nil {
callback(toolCall)
callback(toolReply)
callback(toolResult)
}
// Recurse into CreateChatCompletionStream with the tool call replies
// added to the original messages
messages = append(append(messages, toolCall), toolReply)
if continuation {
messages[len(messages)-1] = toolCall
} else {
messages = append(messages, toolCall)
}
messages = append(messages, toolResult)
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
}
}
case "message_stop":
// return the completed message
content := sb.String()
if callback != nil {
if !isToolCall {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: sb.String(),
})
}
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: content,
})
}
return sb.String(), nil
return content, nil
case "error":
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
default:

View File

@ -137,7 +137,15 @@ func handleToolCalls(
params model.RequestParameters,
content string,
toolCalls []openai.ToolCall,
callback provider.ReplyCallback,
messages []model.Message,
) ([]model.Message, error) {
lastMessage := messages[len(messages)-1]
continuation := false
if lastMessage.Role.IsAssistant() {
continuation = true
}
toolCall := model.Message{
Role: model.MessageRoleToolCall,
Content: content,
@ -154,7 +162,19 @@ func handleToolCalls(
ToolResults: toolResults,
}
return []model.Message{toolCall, toolResult}, nil
if callback != nil {
callback(toolCall)
callback(toolResult)
}
if continuation {
messages[len(messages)-1] = toolCall
} else {
messages = append(messages, toolCall)
}
messages = append(messages, toolResult)
return messages, nil
}
func (c *OpenAIClient) CreateChatCompletion(
@ -163,6 +183,10 @@ func (c *OpenAIClient) CreateChatCompletion(
messages []model.Message,
callback provider.ReplyCallback,
) (string, error) {
if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages")
}
client := openai.NewClient(c.APIKey)
req := createChatCompletionRequest(c, params, messages)
resp, err := client.CreateChatCompletion(ctx, req)
@ -172,32 +196,33 @@ func (c *OpenAIClient) CreateChatCompletion(
choice := resp.Choices[0]
var content string
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
content = lastMessage.Content + choice.Message.Content
} else {
content = choice.Message.Content
}
toolCalls := choice.Message.ToolCalls
if len(toolCalls) > 0 {
results, err := handleToolCalls(params, choice.Message.Content, toolCalls)
messages, err := handleToolCalls(params, content, toolCalls, callback, messages)
if err != nil {
return "", err
}
if callback != nil {
for _, result := range results {
callback(result)
}
return content, err
}
// Recurse into CreateChatCompletion with the tool call replies
messages = append(messages, results...)
return c.CreateChatCompletion(ctx, params, messages, callback)
}
if callback != nil {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: choice.Message.Content,
Content: content,
})
}
// Return the user-facing message.
return choice.Message.Content, nil
return content, nil
}
func (c *OpenAIClient) CreateChatCompletionStream(
@ -207,6 +232,10 @@ func (c *OpenAIClient) CreateChatCompletionStream(
callback provider.ReplyCallback,
output chan<- string,
) (string, error) {
if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages")
}
client := openai.NewClient(c.APIKey)
req := createChatCompletionRequest(c, params, messages)
@ -219,6 +248,11 @@ func (c *OpenAIClient) CreateChatCompletionStream(
content := strings.Builder{}
toolCalls := []openai.ToolCall{}
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
content.WriteString(lastMessage.Content)
}
// Iterate stream segments
for {
response, e := stream.Recv()
@ -251,19 +285,12 @@ func (c *OpenAIClient) CreateChatCompletionStream(
}
if len(toolCalls) > 0 {
results, err := handleToolCalls(params, content.String(), toolCalls)
messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages)
if err != nil {
return content.String(), err
}
if callback != nil {
for _, result := range results {
callback(result)
}
}
// Recurse into CreateChatCompletionStream with the tool call replies
messages = append(messages, results...)
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
} else {
if callback != nil {

View File

@ -152,6 +152,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case "ctrl+c":
if m.waitingForReply {
m.stopSignal <- ""
return m, nil
} else {
return m, tea.Quit
}
@ -192,7 +193,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case msgResponseChunk:
chunk := string(msg)
last := len(m.messages) - 1
if last >= 0 && m.messages[last].Role == models.MessageRoleAssistant {
if last >= 0 && m.messages[last].Role.IsAssistant() {
m.setMessageContents(last, m.messages[last].Content+chunk)
} else {
m.addMessage(models.Message{
@ -205,17 +206,16 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case msgAssistantReply:
// the last reply that was being worked on is finished
reply := models.Message(msg)
reply.Content = strings.TrimSpace(reply.Content)
last := len(m.messages) - 1
if last < 0 {
panic("Unexpected empty messages handling msgReply")
panic("Unexpected empty messages handling msgAssistantReply")
}
m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content))
if m.messages[last].Role == models.MessageRoleAssistant {
// the last message was an assistant message, so this is a continuation
if reply.Role == models.MessageRoleToolCall {
// update last message rrole to tool call
m.messages[last].Role = models.MessageRoleToolCall
}
if reply.Role.IsAssistant() && m.messages[last].Role.IsAssistant() {
// this was a continuation, so replace the previous message with the completed reply
m.setMessage(last, reply)
} else {
m.addMessage(reply)
}
@ -496,9 +496,8 @@ func (m *model) handleMessagesKey(msg tea.KeyMsg) tea.Cmd {
scrollIntoView(&m.content, offset, 0.1)
}
case "ctrl+r":
// resubmit the conversation with all messages up until and including
// the selected message
if len(m.messages) == 0 {
// resubmit the conversation with all messages up until and including the selected message
if m.waitingForReply || len(m.messages) == 0 {
return nil
}
m.messages = m.messages[:m.selectedMessage+1]
@ -543,13 +542,12 @@ func (m *model) handleInputKey(msg tea.KeyMsg) tea.Cmd {
return wrapError(err)
}
// ensure all messages up to the one we're about to add are
// persistent
// ensure all messages up to the one we're about to add are persisted
cmd := m.persistConversation()
if cmd != nil {
return cmd
}
// persist our new message, returning with any possible errors
savedReply, err := m.ctx.Store.AddReply(m.conversation, reply)
if err != nil {
return wrapError(err)