From 91c74d9e1efb066b9887118adbb2aa98b0cdce65 Mon Sep 17 00:00:00 2001
From: Matt Low
Date: Fri, 22 Mar 2024 17:51:01 +0000
Subject: [PATCH] Update CreateChatCompletion behavior

When the last message in the passed messages slice is an assistant
message, treat it as a partial message that is being continued, and
include its content in the newly created reply.

Update TUI code to handle the new behavior.
---
 pkg/lmcli/model/conversation.go           | 22 +++++--
 pkg/lmcli/provider/anthropic/anthropic.go | 79 +++++++++++++++--------
 pkg/lmcli/provider/openai/openai.go       | 67 +++++++++++++------
 pkg/tui/tui.go                            | 28 ++++----
 4 files changed, 127 insertions(+), 69 deletions(-)

diff --git a/pkg/lmcli/model/conversation.go b/pkg/lmcli/model/conversation.go
index 5494b90..02b5c0e 100644
--- a/pkg/lmcli/model/conversation.go
+++ b/pkg/lmcli/model/conversation.go
@@ -41,18 +41,28 @@ type RequestParameters struct {
 	ToolBag []Tool
 }
 
+func (m *MessageRole) IsAssistant() bool {
+	switch *m {
+	case MessageRoleAssistant, MessageRoleToolCall:
+		return true
+	}
+	return false
+}
+
 // FriendlyRole returns a human friendly signifier for the message's role.
 func (m *MessageRole) FriendlyRole() string {
-	var friendlyRole string
 	switch *m {
 	case MessageRoleUser:
-		friendlyRole = "You"
+		return "You"
 	case MessageRoleSystem:
-		friendlyRole = "System"
+		return "System"
 	case MessageRoleAssistant:
-		friendlyRole = "Assistant"
+		return "Assistant"
+	case MessageRoleToolCall:
+		return "Tool Call"
+	case MessageRoleToolResult:
+		return "Tool Result"
 	default:
-		friendlyRole = string(*m)
+		return string(*m)
 	}
-	return friendlyRole
 }
diff --git a/pkg/lmcli/provider/anthropic/anthropic.go b/pkg/lmcli/provider/anthropic/anthropic.go
index cff568c..2aa7164 100644
--- a/pkg/lmcli/provider/anthropic/anthropic.go
+++ b/pkg/lmcli/provider/anthropic/anthropic.go
@@ -42,10 +42,12 @@ type OriginalContent struct {
 }
 
 type Response struct {
-	Id      string            `json:"id"`
-	Type    string            `json:"type"`
-	Role    string            `json:"role"`
-	Content []OriginalContent `json:"content"`
+	Id           string            `json:"id"`
+	Type         string            `json:"type"`
+	Role         string            `json:"role"`
+	Content      []OriginalContent `json:"content"`
+	StopReason   string            `json:"stop_reason"`
+	StopSequence string            `json:"stop_sequence"`
 }
 
 const FUNCTION_STOP_SEQUENCE = "</function_calls>"
@@ -147,6 +149,10 @@ func (c *AnthropicClient) CreateChatCompletion(
 	messages []model.Message,
 	callback provider.ReplyCallback,
 ) (string, error) {
+	if len(messages) == 0 {
+		return "", fmt.Errorf("can't create completion from no messages")
+	}
+
 	request := buildRequest(params, messages)
 
 	resp, err := sendRequest(ctx, c, request)
@@ -162,6 +168,14 @@ func (c *AnthropicClient) CreateChatCompletion(
 	}
 
 	sb := strings.Builder{}
+
+	lastMessage := messages[len(messages)-1]
+	if lastMessage.Role.IsAssistant() {
+		// this is a continuation of a previous assistant reply, so we'll
+		// include its contents in the final result
+		sb.WriteString(lastMessage.Content)
+	}
+
 	for _, content := range response.Content {
 		var reply model.Message
 		switch content.Type {
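
For reference, a minimal sketch of how a caller might drive the continuation
behavior added to CreateChatCompletion above. Only the method signature comes
from this patch; the completer interface, the example messages, and the import
paths are assumptions for illustration:

    // Illustrative sketch, not part of this patch.
    package example

    import (
        "context"
        "fmt"
        "log"

        "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"    // import path assumed
        "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider" // import path assumed
    )

    // completer stands in for either provider client.
    type completer interface {
        CreateChatCompletion(ctx context.Context, params model.RequestParameters,
            messages []model.Message, callback provider.ReplyCallback) (string, error)
    }

    func continueReply(ctx context.Context, c completer, params model.RequestParameters) {
        messages := []model.Message{
            {Role: model.MessageRoleUser, Content: "Write a haiku about the sea."},
            // A trailing assistant message is treated as a partial
            // reply that the completion continues.
            {Role: model.MessageRoleAssistant, Content: "Waves fold into foam,"},
        }
        reply, err := c.CreateChatCompletion(ctx, params, messages, nil)
        if err != nil {
            log.Fatal(err)
        }
        // reply contains the partial content plus the newly generated text.
        fmt.Println(reply)
    }
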
@@ -189,6 +203,10 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 	callback provider.ReplyCallback,
 	output chan<- string,
 ) (string, error) {
+	if len(messages) == 0 {
+		return "", fmt.Errorf("can't create completion from no messages")
+	}
+
 	request := buildRequest(params, messages)
 	request.Stream = true
 
@@ -198,11 +216,18 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 	}
 	defer resp.Body.Close()
 
-	scanner := bufio.NewScanner(resp.Body)
 	sb := strings.Builder{}
 
-	isToolCall := false
+	lastMessage := messages[len(messages)-1]
+	continuation := false
+	if lastMessage.Role.IsAssistant() {
+		// this is a continuation of a previous assistant reply, so we'll
+		// include its contents in the final result
+		sb.WriteString(lastMessage.Content)
+		continuation = true
+	}
 
+	scanner := bufio.NewScanner(resp.Body)
 	for scanner.Scan() {
 		line := scanner.Text()
 		line = strings.TrimSpace(line)
@@ -277,26 +302,21 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 					return content, fmt.Errorf("reached stop sequence but no opening tag found")
 				}
 
-				isToolCall = true
-
-				funcCallXml := content[start:]
-				funcCallXml += FUNCTION_STOP_SEQUENCE
-
 				sb.WriteString(FUNCTION_STOP_SEQUENCE)
 				output <- FUNCTION_STOP_SEQUENCE
 
-				// Extract function calls
+				funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE
+
 				var functionCalls XMLFunctionCalls
-				err := xml.Unmarshal([]byte(sb.String()), &functionCalls)
+				err := xml.Unmarshal([]byte(funcCallXml), &functionCalls)
 				if err != nil {
 					return "", fmt.Errorf("failed to unmarshal function_calls: %v", err)
 				}
 
-				// Execute function calls
 				toolCall := model.Message{
 					Role: model.MessageRoleToolCall,
-					// xml stripped from content
-					Content: content[:start],
+					// function call xml stripped from content for model interop
+					Content:   strings.TrimSpace(content[:start]),
 					ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
 				}
@@ -305,33 +325,36 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 				if err != nil {
 					return "", err
 				}
 
-				toolReply := model.Message{
+				toolResult := model.Message{
 					Role:        model.MessageRoleToolResult,
 					ToolResults: toolResults,
 				}
 
 				if callback != nil {
 					callback(toolCall)
-					callback(toolReply)
+					callback(toolResult)
 				}
 
-				// Recurse into CreateChatCompletionStream with the tool call replies
-				// added to the original messages
-				messages = append(append(messages, toolCall), toolReply)
+				if continuation {
+					messages[len(messages)-1] = toolCall
+				} else {
+					messages = append(messages, toolCall)
+				}
+
+				messages = append(messages, toolResult)
 				return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
 			}
 		}
 	case "message_stop":
 		// return the completed message
+		content := sb.String()
 		if callback != nil {
-			if !isToolCall {
-				callback(model.Message{
-					Role:    model.MessageRoleAssistant,
-					Content: sb.String(),
-				})
-			}
+			callback(model.Message{
+				Role:    model.MessageRoleAssistant,
+				Content: content,
+			})
 		}
-		return sb.String(), nil
+		return content, nil
 	case "error":
 		return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
 	default:
diff --git a/pkg/lmcli/provider/openai/openai.go b/pkg/lmcli/provider/openai/openai.go
index 89d1309..0b8224b 100644
--- a/pkg/lmcli/provider/openai/openai.go
+++ b/pkg/lmcli/provider/openai/openai.go
@@ -137,7 +137,15 @@ func handleToolCalls(
 	params model.RequestParameters,
 	content string,
 	toolCalls []openai.ToolCall,
+	callback provider.ReplyCallback,
+	messages []model.Message,
 ) ([]model.Message, error) {
+	lastMessage := messages[len(messages)-1]
+	continuation := false
+	if lastMessage.Role.IsAssistant() {
+		continuation = true
+	}
+
 	toolCall := model.Message{
 		Role:    model.MessageRoleToolCall,
 		Content: content,
@@ -154,7 +162,19 @@ func handleToolCalls(
 		ToolResults: toolResults,
 	}
 
-	return []model.Message{toolCall, toolResult}, nil
+	if callback != nil {
+		callback(toolCall)
+		callback(toolResult)
+	}
+
+	if continuation {
+		messages[len(messages)-1] = toolCall
+	} else {
+		messages = append(messages, toolCall)
+	}
+	messages = append(messages, toolResult)
+
+	return messages, nil
 }
 
 func (c *OpenAIClient) CreateChatCompletion(
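
Both providers now perform the same splice when a tool call interrupts a
continuation: the trailing partial assistant message is replaced by the
completed tool call message, otherwise the tool call is appended, and the
tool result always follows. A distilled restatement of that invariant, for
illustration only (assumes the model package shown above is imported):

    // Not code from this patch; a restatement of the shared splice logic.
    func spliceToolMessages(messages []model.Message, toolCall, toolResult model.Message) []model.Message {
        if messages[len(messages)-1].Role.IsAssistant() {
            // continuation: the partial assistant message is superseded
            messages[len(messages)-1] = toolCall
        } else {
            messages = append(messages, toolCall)
        }
        return append(messages, toolResult)
    }
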
@@ -163,6 +183,10 @@ func (c *OpenAIClient) CreateChatCompletion(
 	messages []model.Message,
 	callback provider.ReplyCallback,
 ) (string, error) {
+	if len(messages) == 0 {
+		return "", fmt.Errorf("can't create completion from no messages")
+	}
+
 	client := openai.NewClient(c.APIKey)
 	req := createChatCompletionRequest(c, params, messages)
 	resp, err := client.CreateChatCompletion(ctx, req)
@@ -172,32 +196,33 @@ func (c *OpenAIClient) CreateChatCompletion(
 
 	choice := resp.Choices[0]
 
+	var content string
+	lastMessage := messages[len(messages)-1]
+	if lastMessage.Role.IsAssistant() {
+		content = lastMessage.Content + choice.Message.Content
+	} else {
+		content = choice.Message.Content
+	}
+
 	toolCalls := choice.Message.ToolCalls
 	if len(toolCalls) > 0 {
-		results, err := handleToolCalls(params, choice.Message.Content, toolCalls)
+		messages, err := handleToolCalls(params, content, toolCalls, callback, messages)
 		if err != nil {
-			return "", err
-		}
-		if callback != nil {
-			for _, result := range results {
-				callback(result)
-			}
+			return content, err
 		}
 
-		// Recurse into CreateChatCompletion with the tool call replies
-		messages = append(messages, results...)
 		return c.CreateChatCompletion(ctx, params, messages, callback)
 	}
 
 	if callback != nil {
 		callback(model.Message{
 			Role:    model.MessageRoleAssistant,
-			Content: choice.Message.Content,
+			Content: content,
 		})
 	}
 
 	// Return the user-facing message.
-	return choice.Message.Content, nil
+	return content, nil
 }
 
 func (c *OpenAIClient) CreateChatCompletionStream(
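
Because handleToolCalls now invokes the callback itself, a single callback
observes the whole exchange: tool calls, tool results, and the final assistant
reply. A hypothetical collector built on that behavior (client, ctx, params,
and messages are assumed to be in scope; only the callback signature comes
from this patch):

    // Hypothetical usage sketch, not part of this patch.
    var replies []model.Message
    collect := func(reply model.Message) {
        replies = append(replies, reply)
    }
    content, err := client.CreateChatCompletion(ctx, params, messages, collect)
    if err != nil {
        log.Println(err)
    }
    // On continuation, content carries the partial reply as a prefix.
    fmt.Printf("final content: %s (%d replies)\n", content, len(replies))
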
@@ -207,6 +232,10 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 	callback provider.ReplyCallback,
 	output chan<- string,
 ) (string, error) {
+	if len(messages) == 0 {
+		return "", fmt.Errorf("can't create completion from no messages")
+	}
+
 	client := openai.NewClient(c.APIKey)
 	req := createChatCompletionRequest(c, params, messages)
 
@@ -219,6 +248,11 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 	content := strings.Builder{}
 	toolCalls := []openai.ToolCall{}
 
+	lastMessage := messages[len(messages)-1]
+	if lastMessage.Role.IsAssistant() {
+		content.WriteString(lastMessage.Content)
+	}
+
 	// Iterate stream segments
 	for {
 		response, e := stream.Recv()
@@ -251,19 +285,12 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 	}
 
 	if len(toolCalls) > 0 {
-		results, err := handleToolCalls(params, content.String(), toolCalls)
+		messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages)
 		if err != nil {
 			return content.String(), err
 		}
-		if callback != nil {
-			for _, result := range results {
-				callback(result)
-			}
-		}
 
-		// Recurse into CreateChatCompletionStream with the tool call replies
-		messages = append(messages, results...)
 		return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
 	} else {
 		if callback != nil {
diff --git a/pkg/tui/tui.go b/pkg/tui/tui.go
index bb5cb51..c82019b 100644
--- a/pkg/tui/tui.go
+++ b/pkg/tui/tui.go
@@ -152,6 +152,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 		case "ctrl+c":
 			if m.waitingForReply {
 				m.stopSignal <- ""
+				return m, nil
 			} else {
 				return m, tea.Quit
 			}
@@ -192,7 +193,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 	case msgResponseChunk:
 		chunk := string(msg)
 		last := len(m.messages) - 1
-		if last >= 0 && m.messages[last].Role == models.MessageRoleAssistant {
+		if last >= 0 && m.messages[last].Role.IsAssistant() {
 			m.setMessageContents(last, m.messages[last].Content+chunk)
 		} else {
 			m.addMessage(models.Message{
@@ -205,17 +206,16 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 	case msgAssistantReply:
 		// the last reply that was being worked on is finished
 		reply := models.Message(msg)
+		reply.Content = strings.TrimSpace(reply.Content)
+
 		last := len(m.messages) - 1
 		if last < 0 {
-			panic("Unexpected empty messages handling msgReply")
+			panic("Unexpected empty messages handling msgAssistantReply")
 		}
-		m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content))
 
-		if m.messages[last].Role == models.MessageRoleAssistant {
-			// the last message was an assistant message, so this is a continuation
-			if reply.Role == models.MessageRoleToolCall {
-				// update last message role to tool call
-				m.messages[last].Role = models.MessageRoleToolCall
-			}
+		if reply.Role.IsAssistant() && m.messages[last].Role.IsAssistant() {
+			// this was a continuation, so replace the previous message with the completed reply
+			m.setMessage(last, reply)
 		} else {
 			m.addMessage(reply)
 		}
@@ -496,9 +496,8 @@ func (m *model) handleMessagesKey(msg tea.KeyMsg) tea.Cmd {
 			scrollIntoView(&m.content, offset, 0.1)
 		}
 	case "ctrl+r":
-		// resubmit the conversation with all messages up until and including
-		// the selected message
-		if len(m.messages) == 0 {
+		// resubmit the conversation with all messages up until and including the selected message
+		if m.waitingForReply || len(m.messages) == 0 {
 			return nil
 		}
 		m.messages = m.messages[:m.selectedMessage+1]
@@ -543,13 +542,12 @@ func (m *model) handleInputKey(msg tea.KeyMsg) tea.Cmd {
 			return wrapError(err)
 		}
 
-		// ensure all messages up to the one we're about to add are
-		// persistent
+		// ensure all messages up to the one we're about to add are persisted
 		cmd := m.persistConversation()
 		if cmd != nil {
 			return cmd
 		}
 
-		// persist our new message, returning with any possible errors
+
 		savedReply, err := m.ctx.Store.AddReply(m.conversation, reply)
 		if err != nil {
 			return wrapError(err)
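
Taken together, the TUI consumes the streaming API through two paths: text
chunks arrive on the output channel (the msgResponseChunk path) while
completed messages arrive through the callback (the msgAssistantReply path).
A simplified sketch of that wiring; the persist helper and the surrounding
scaffolding are illustrative, not from this patch:

    // Illustrative wiring only.
    output := make(chan string)
    done := make(chan struct{})
    go func() {
        defer close(done)
        for chunk := range output {
            fmt.Print(chunk) // msgResponseChunk: extend the in-progress message
        }
    }()

    reply, err := client.CreateChatCompletionStream(ctx, params, messages,
        func(msg model.Message) {
            persist(msg) // msgAssistantReply: a completed (possibly tool) message
        }, output)
    close(output)
    <-done
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println("\nfinal reply:", reply)
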