Update CreateChatCompletion behavior
When the last message in the passed messages slice is an assistant message, treat it as a partial message that is being continued, and include its content in the newly created reply.

Update TUI code to handle the new behavior.
parent 3185b2d7d6
commit 91c74d9e1e
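
The gist of the change, as a minimal self-contained sketch. The types below are simplified stand-ins for the repo's model.Message and MessageRole, and completionPrefix is an illustrative helper, not a function introduced by this commit:

package main

import "fmt"

type MessageRole string

const (
	MessageRoleUser      MessageRole = "user"
	MessageRoleAssistant MessageRole = "assistant"
)

type Message struct {
	Role    MessageRole
	Content string
}

// completionPrefix captures the new rule: when the final message is an
// assistant message, treat it as a partial reply being continued and seed
// the new reply with its content.
func completionPrefix(messages []Message) string {
	if len(messages) == 0 {
		return ""
	}
	if last := messages[len(messages)-1]; last.Role == MessageRoleAssistant {
		return last.Content
	}
	return ""
}

func main() {
	msgs := []Message{
		{MessageRoleUser, "Count to four."},
		{MessageRoleAssistant, "One, two,"}, // partial reply being continued
	}
	reply := completionPrefix(msgs) + " three, four." // streamed chunks append here
	fmt.Println(reply)                                // the reply includes the partial prefix
}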
@@ -41,18 +41,28 @@ type RequestParameters struct {
 	ToolBag []Tool
 }
 
+func (m *MessageRole) IsAssistant() bool {
+	switch *m {
+	case MessageRoleAssistant, MessageRoleToolCall:
+		return true
+	}
+	return false
+}
+
 // FriendlyRole returns a human friendly signifier for the message's role.
 func (m *MessageRole) FriendlyRole() string {
-	var friendlyRole string
 	switch *m {
 	case MessageRoleUser:
-		friendlyRole = "You"
+		return "You"
 	case MessageRoleSystem:
-		friendlyRole = "System"
+		return "System"
 	case MessageRoleAssistant:
-		friendlyRole = "Assistant"
+		return "Assistant"
+	case MessageRoleToolCall:
+		return "Tool Call"
+	case MessageRoleToolResult:
+		return "Tool Result"
 	default:
-		friendlyRole = string(*m)
+		return string(*m)
 	}
-	return friendlyRole
 }
@@ -42,10 +42,12 @@ type OriginalContent struct {
 }
 
 type Response struct {
-	Id      string            `json:"id"`
-	Type    string            `json:"type"`
-	Role    string            `json:"role"`
-	Content []OriginalContent `json:"content"`
+	Id           string            `json:"id"`
+	Type         string            `json:"type"`
+	Role         string            `json:"role"`
+	Content      []OriginalContent `json:"content"`
+	StopReason   string            `json:"stop_reason"`
+	StopSequence string            `json:"stop_sequence"`
 }
 
 const FUNCTION_STOP_SEQUENCE = "</function_calls>"
@@ -147,6 +149,10 @@ func (c *AnthropicClient) CreateChatCompletion(
 	messages []model.Message,
 	callback provider.ReplyCallback,
 ) (string, error) {
+	if len(messages) == 0 {
+		return "", fmt.Errorf("Can't create completion from no messages")
+	}
+
 	request := buildRequest(params, messages)
 
 	resp, err := sendRequest(ctx, c, request)
@@ -162,6 +168,14 @@ func (c *AnthropicClient) CreateChatCompletion(
 	}
 
 	sb := strings.Builder{}
+
+	lastMessage := messages[len(messages)-1]
+	if lastMessage.Role.IsAssistant() {
+		// this is a continuation of a previous assistant reply, so we'll
+		// include its contents in the final result
+		sb.WriteString(lastMessage.Content)
+	}
+
 	for _, content := range response.Content {
 		var reply model.Message
 		switch content.Type {
@@ -189,6 +203,10 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 	callback provider.ReplyCallback,
 	output chan<- string,
 ) (string, error) {
+	if len(messages) == 0 {
+		return "", fmt.Errorf("Can't create completion from no messages")
+	}
+
 	request := buildRequest(params, messages)
 	request.Stream = true
 
@@ -198,11 +216,18 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 	}
 	defer resp.Body.Close()
 
-	scanner := bufio.NewScanner(resp.Body)
 	sb := strings.Builder{}
 
 	isToolCall := false
-
+	lastMessage := messages[len(messages)-1]
+	continuation := false
+	if messages[len(messages)-1].Role.IsAssistant() {
+		// this is a continuation of a previous assistant reply, so we'll
+		// include its contents in the final result
+		sb.WriteString(lastMessage.Content)
+		continuation = true
+	}
+
+	scanner := bufio.NewScanner(resp.Body)
 	for scanner.Scan() {
 		line := scanner.Text()
 		line = strings.TrimSpace(line)
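
Note: the streaming path seeds its accumulator the same way — the builder that collects streamed chunks starts out holding the partial assistant content, so the stop-sequence scanning and the final callback both see the combined text. A condensed restatement under simplified types (seedBuilder is my name for it, not the repo's):

package sketch

import "strings"

type Message struct {
	Role    string
	Content string
}

// seedBuilder restates the setup added to CreateChatCompletionStream: when
// the last message is an assistant message, the accumulating builder is
// pre-filled with its content and the call is flagged as a continuation.
func seedBuilder(messages []Message) (*strings.Builder, bool) {
	sb := &strings.Builder{}
	continuation := false
	if last := messages[len(messages)-1]; last.Role == "assistant" {
		sb.WriteString(last.Content)
		continuation = true
	}
	return sb, continuation
}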
@@ -277,26 +302,21 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 				return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
 			}
 
 			isToolCall = true
 
-			funcCallXml := content[start:]
-			funcCallXml += FUNCTION_STOP_SEQUENCE
-
 			sb.WriteString(FUNCTION_STOP_SEQUENCE)
 			output <- FUNCTION_STOP_SEQUENCE
 
+			// Extract function calls
+			funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE
+
 			var functionCalls XMLFunctionCalls
-			err := xml.Unmarshal([]byte(sb.String()), &functionCalls)
+			err := xml.Unmarshal([]byte(funcCallXml), &functionCalls)
 			if err != nil {
 				return "", fmt.Errorf("failed to unmarshal function_calls: %v", err)
 			}
 
 			// Execute function calls
 			toolCall := model.Message{
 				Role: model.MessageRoleToolCall,
-				// xml stripped from content
-				Content: content[:start],
+				// function call xml stripped from content for model interop
+				Content: strings.TrimSpace(content[:start]),
 				ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
 			}
@@ -305,33 +325,36 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 				return "", err
 			}
 
-			toolReply := model.Message{
+			toolResult := model.Message{
 				Role:        model.MessageRoleToolResult,
 				ToolResults: toolResults,
 			}
 
 			if callback != nil {
 				callback(toolCall)
-				callback(toolReply)
+				callback(toolResult)
 			}
 
 			// Recurse into CreateChatCompletionStream with the tool call replies
-			// added to the original messages
-			messages = append(append(messages, toolCall), toolReply)
+			if continuation {
+				messages[len(messages)-1] = toolCall
+			} else {
+				messages = append(messages, toolCall)
+			}
+			messages = append(messages, toolResult)
 			return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
 		}
 	}
 case "message_stop":
 	// return the completed message
+	content := sb.String()
 	if callback != nil {
-		if !isToolCall {
-			callback(model.Message{
-				Role: model.MessageRoleAssistant,
-				Content: sb.String(),
-			})
-		}
+		callback(model.Message{
+			Role: model.MessageRoleAssistant,
+			Content: content,
+		})
 	}
-	return sb.String(), nil
+	return content, nil
 case "error":
 	return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
 default:
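
Note on the tool-call branch above: when the partial assistant message turns out to be the start of a tool call, the continuation flag makes the code replace the trailing message rather than append a duplicate, so the recursive call sees a clean history. Condensed and self-contained (spliceToolCall is illustrative, not the repo's API):

package sketch

type Message struct {
	Role    string
	Content string
}

// spliceToolCall restates the pattern: on a continuation, the partial
// assistant message at the end of the slice is replaced by the completed
// tool call; otherwise the tool call is appended. The tool result follows
// in either case.
func spliceToolCall(messages []Message, toolCall, toolResult Message, continuation bool) []Message {
	if continuation {
		messages[len(messages)-1] = toolCall
	} else {
		messages = append(messages, toolCall)
	}
	return append(messages, toolResult)
}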
@@ -137,7 +137,15 @@ func handleToolCalls(
 	params model.RequestParameters,
 	content string,
 	toolCalls []openai.ToolCall,
+	callback provider.ReplyCallback,
+	messages []model.Message,
 ) ([]model.Message, error) {
+	lastMessage := messages[len(messages)-1]
+	continuation := false
+	if lastMessage.Role.IsAssistant() {
+		continuation = true
+	}
+
 	toolCall := model.Message{
 		Role: model.MessageRoleToolCall,
 		Content: content,
@@ -154,7 +162,19 @@ func handleToolCalls(
 		ToolResults: toolResults,
 	}
 
-	return []model.Message{toolCall, toolResult}, nil
+	if callback != nil {
+		callback(toolCall)
+		callback(toolResult)
+	}
+
+	if continuation {
+		messages[len(messages)-1] = toolCall
+	} else {
+		messages = append(messages, toolCall)
+	}
+	messages = append(messages, toolResult)
+
+	return messages, nil
 }
 
 func (c *OpenAIClient) CreateChatCompletion(
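
Note the ownership change in the openai provider: callback invocation moves from the call sites into handleToolCalls, which now receives the full history and returns the updated slice with the same splice rule applied. A rough approximation of the new contract under simplified stand-in types (tool execution elided; assumes a non-empty messages slice, as the callers above enforce):

package sketch

type (
	Message       struct{ Role, Content string }
	ReplyCallback func(Message)
)

// handleToolCallsShape approximates the revised handleToolCalls: callers no
// longer append results or fire callbacks themselves; they recurse on the
// returned slice.
func handleToolCallsShape(content string, callback ReplyCallback, messages []Message) []Message {
	toolCall := Message{Role: "tool_call", Content: content}
	toolResult := Message{Role: "tool_result"} // results of executing the calls, elided
	if callback != nil {
		callback(toolCall)
		callback(toolResult)
	}
	if messages[len(messages)-1].Role == "assistant" {
		messages[len(messages)-1] = toolCall // continuation: replace the partial reply
	} else {
		messages = append(messages, toolCall)
	}
	return append(messages, toolResult)
}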
@@ -163,6 +183,10 @@ func (c *OpenAIClient) CreateChatCompletion(
 	messages []model.Message,
 	callback provider.ReplyCallback,
 ) (string, error) {
+	if len(messages) == 0 {
+		return "", fmt.Errorf("Can't create completion from no messages")
+	}
+
 	client := openai.NewClient(c.APIKey)
 	req := createChatCompletionRequest(c, params, messages)
 	resp, err := client.CreateChatCompletion(ctx, req)
@@ -172,32 +196,33 @@ func (c *OpenAIClient) CreateChatCompletion(
 	choice := resp.Choices[0]
 
+	var content string
+	lastMessage := messages[len(messages)-1]
+	if lastMessage.Role.IsAssistant() {
+		content = lastMessage.Content + choice.Message.Content
+	} else {
+		content = choice.Message.Content
+	}
+
 	toolCalls := choice.Message.ToolCalls
 	if len(toolCalls) > 0 {
-		results, err := handleToolCalls(params, choice.Message.Content, toolCalls)
+		messages, err := handleToolCalls(params, content, toolCalls, callback, messages)
 		if err != nil {
-			return "", err
-		}
-		if callback != nil {
-			for _, result := range results {
-				callback(result)
-			}
+			return content, err
 		}
 
 		// Recurse into CreateChatCompletion with the tool call replies
-		messages = append(messages, results...)
 		return c.CreateChatCompletion(ctx, params, messages, callback)
 	}
 
 	if callback != nil {
 		callback(model.Message{
 			Role: model.MessageRoleAssistant,
-			Content: choice.Message.Content,
+			Content: content,
 		})
 	}
 
 	// Return the user-facing message.
-	return choice.Message.Content, nil
+	return content, nil
 }
 
 func (c *OpenAIClient) CreateChatCompletionStream(
@@ -207,6 +232,10 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 	callback provider.ReplyCallback,
 	output chan<- string,
 ) (string, error) {
+	if len(messages) == 0 {
+		return "", fmt.Errorf("Can't create completion from no messages")
+	}
+
 	client := openai.NewClient(c.APIKey)
 	req := createChatCompletionRequest(c, params, messages)
 
@@ -219,6 +248,11 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 	content := strings.Builder{}
 	toolCalls := []openai.ToolCall{}
 
+	lastMessage := messages[len(messages)-1]
+	if lastMessage.Role.IsAssistant() {
+		content.WriteString(lastMessage.Content)
+	}
+
 	// Iterate stream segments
 	for {
 		response, e := stream.Recv()
@@ -251,19 +285,12 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 	}
 
 	if len(toolCalls) > 0 {
-		results, err := handleToolCalls(params, content.String(), toolCalls)
+		messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages)
 		if err != nil {
 			return content.String(), err
 		}
 
-		if callback != nil {
-			for _, result := range results {
-				callback(result)
-			}
-		}
-
 		// Recurse into CreateChatCompletionStream with the tool call replies
-		messages = append(messages, results...)
 		return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
 	} else {
 		if callback != nil {
@@ -152,6 +152,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 	case "ctrl+c":
 		if m.waitingForReply {
 			m.stopSignal <- ""
+			return m, nil
 		} else {
 			return m, tea.Quit
 		}
@@ -192,7 +193,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 	case msgResponseChunk:
 		chunk := string(msg)
 		last := len(m.messages) - 1
-		if last >= 0 && m.messages[last].Role == models.MessageRoleAssistant {
+		if last >= 0 && m.messages[last].Role.IsAssistant() {
 			m.setMessageContents(last, m.messages[last].Content+chunk)
 		} else {
 			m.addMessage(models.Message{
@@ -205,17 +206,16 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 	case msgAssistantReply:
 		// the last reply that was being worked on is finished
 		reply := models.Message(msg)
+		reply.Content = strings.TrimSpace(reply.Content)
+
 		last := len(m.messages) - 1
 		if last < 0 {
-			panic("Unexpected empty messages handling msgReply")
+			panic("Unexpected empty messages handling msgAssistantReply")
 		}
-		m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content))
-		if m.messages[last].Role == models.MessageRoleAssistant {
-			// the last message was an assistant message, so this is a continuation
-			if reply.Role == models.MessageRoleToolCall {
-				// update last message rrole to tool call
-				m.messages[last].Role = models.MessageRoleToolCall
-			}
+
+		if reply.Role.IsAssistant() && m.messages[last].Role.IsAssistant() {
+			// this was a continuation, so replace the previous message with the completed reply
+			m.setMessage(last, reply)
 		} else {
 			m.addMessage(reply)
 		}
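
On the TUI side, a finished reply that continues the trailing assistant message now replaces that message instead of being appended. A compact restatement with simplified types (applyReply and isAssistant are illustrative; the latter mirrors MessageRole.IsAssistant above):

package sketch

type Message struct {
	Role    string
	Content string
}

func isAssistant(role string) bool {
	return role == "assistant" || role == "tool_call"
}

// applyReply mirrors the msgAssistantReply handling: when both the incoming
// reply and the last message are assistant-side messages, the reply is a
// completed continuation and replaces the partial message in place.
func applyReply(messages []Message, reply Message) []Message {
	last := len(messages) - 1
	if last >= 0 && isAssistant(reply.Role) && isAssistant(messages[last].Role) {
		messages[last] = reply
		return messages
	}
	return append(messages, reply)
}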
@@ -496,9 +496,8 @@ func (m *model) handleMessagesKey(msg tea.KeyMsg) tea.Cmd {
 			scrollIntoView(&m.content, offset, 0.1)
 		}
 	case "ctrl+r":
-		// resubmit the conversation with all messages up until and including
-		// the selected message
-		if len(m.messages) == 0 {
+		// resubmit the conversation with all messages up until and including the selected message
+		if m.waitingForReply || len(m.messages) == 0 {
 			return nil
 		}
 		m.messages = m.messages[:m.selectedMessage+1]
@@ -543,13 +542,12 @@ func (m *model) handleInputKey(msg tea.KeyMsg) tea.Cmd {
 		return wrapError(err)
 	}
 
-	// ensure all messages up to the one we're about to add are
-	// persistent
+	// ensure all messages up to the one we're about to add are persisted
 	cmd := m.persistConversation()
 	if cmd != nil {
 		return cmd
 	}
-	// persist our new message, returning with any possible errors
+
 	savedReply, err := m.ctx.Store.AddReply(m.conversation, reply)
 	if err != nil {
 		return wrapError(err)