Update CreateChatCompletion behavior
When the last message in the passed messages slice is an assistant message, treat it as a partial message that is being continued, and include its content in the newly created reply Update TUI code to handle new behavior
This commit is contained in:
parent
3185b2d7d6
commit
91c74d9e1e
|
@ -41,18 +41,28 @@ type RequestParameters struct {
|
||||||
ToolBag []Tool
|
ToolBag []Tool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MessageRole) IsAssistant() bool {
|
||||||
|
switch *m {
|
||||||
|
case MessageRoleAssistant, MessageRoleToolCall:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// FriendlyRole returns a human friendly signifier for the message's role.
|
// FriendlyRole returns a human friendly signifier for the message's role.
|
||||||
func (m *MessageRole) FriendlyRole() string {
|
func (m *MessageRole) FriendlyRole() string {
|
||||||
var friendlyRole string
|
|
||||||
switch *m {
|
switch *m {
|
||||||
case MessageRoleUser:
|
case MessageRoleUser:
|
||||||
friendlyRole = "You"
|
return "You"
|
||||||
case MessageRoleSystem:
|
case MessageRoleSystem:
|
||||||
friendlyRole = "System"
|
return "System"
|
||||||
case MessageRoleAssistant:
|
case MessageRoleAssistant:
|
||||||
friendlyRole = "Assistant"
|
return "Assistant"
|
||||||
|
case MessageRoleToolCall:
|
||||||
|
return "Tool Call"
|
||||||
|
case MessageRoleToolResult:
|
||||||
|
return "Tool Result"
|
||||||
default:
|
default:
|
||||||
friendlyRole = string(*m)
|
return string(*m)
|
||||||
}
|
}
|
||||||
return friendlyRole
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,6 +46,8 @@ type Response struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content []OriginalContent `json:"content"`
|
Content []OriginalContent `json:"content"`
|
||||||
|
StopReason string `json:"stop_reason"`
|
||||||
|
StopSequence string `json:"stop_sequence"`
|
||||||
}
|
}
|
||||||
|
|
||||||
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
|
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
|
||||||
|
@ -147,6 +149,10 @@ func (c *AnthropicClient) CreateChatCompletion(
|
||||||
messages []model.Message,
|
messages []model.Message,
|
||||||
callback provider.ReplyCallback,
|
callback provider.ReplyCallback,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
|
if len(messages) == 0 {
|
||||||
|
return "", fmt.Errorf("Can't create completion from no messages")
|
||||||
|
}
|
||||||
|
|
||||||
request := buildRequest(params, messages)
|
request := buildRequest(params, messages)
|
||||||
|
|
||||||
resp, err := sendRequest(ctx, c, request)
|
resp, err := sendRequest(ctx, c, request)
|
||||||
|
@ -162,6 +168,14 @@ func (c *AnthropicClient) CreateChatCompletion(
|
||||||
}
|
}
|
||||||
|
|
||||||
sb := strings.Builder{}
|
sb := strings.Builder{}
|
||||||
|
|
||||||
|
lastMessage := messages[len(messages)-1]
|
||||||
|
if lastMessage.Role.IsAssistant() {
|
||||||
|
// this is a continuation of a previous assistant reply, so we'll
|
||||||
|
// include its contents in the final result
|
||||||
|
sb.WriteString(lastMessage.Content)
|
||||||
|
}
|
||||||
|
|
||||||
for _, content := range response.Content {
|
for _, content := range response.Content {
|
||||||
var reply model.Message
|
var reply model.Message
|
||||||
switch content.Type {
|
switch content.Type {
|
||||||
|
@ -189,6 +203,10 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
||||||
callback provider.ReplyCallback,
|
callback provider.ReplyCallback,
|
||||||
output chan<- string,
|
output chan<- string,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
|
if len(messages) == 0 {
|
||||||
|
return "", fmt.Errorf("Can't create completion from no messages")
|
||||||
|
}
|
||||||
|
|
||||||
request := buildRequest(params, messages)
|
request := buildRequest(params, messages)
|
||||||
request.Stream = true
|
request.Stream = true
|
||||||
|
|
||||||
|
@ -198,11 +216,18 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
sb := strings.Builder{}
|
sb := strings.Builder{}
|
||||||
|
|
||||||
isToolCall := false
|
lastMessage := messages[len(messages)-1]
|
||||||
|
continuation := false
|
||||||
|
if messages[len(messages)-1].Role.IsAssistant() {
|
||||||
|
// this is a continuation of a previous assistant reply, so we'll
|
||||||
|
// include its contents in the final result
|
||||||
|
sb.WriteString(lastMessage.Content)
|
||||||
|
continuation = true
|
||||||
|
}
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
line = strings.TrimSpace(line)
|
line = strings.TrimSpace(line)
|
||||||
|
@ -277,26 +302,21 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
||||||
return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
|
return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
|
||||||
}
|
}
|
||||||
|
|
||||||
isToolCall = true
|
|
||||||
|
|
||||||
funcCallXml := content[start:]
|
|
||||||
funcCallXml += FUNCTION_STOP_SEQUENCE
|
|
||||||
|
|
||||||
sb.WriteString(FUNCTION_STOP_SEQUENCE)
|
sb.WriteString(FUNCTION_STOP_SEQUENCE)
|
||||||
output <- FUNCTION_STOP_SEQUENCE
|
output <- FUNCTION_STOP_SEQUENCE
|
||||||
|
|
||||||
// Extract function calls
|
funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE
|
||||||
|
|
||||||
var functionCalls XMLFunctionCalls
|
var functionCalls XMLFunctionCalls
|
||||||
err := xml.Unmarshal([]byte(sb.String()), &functionCalls)
|
err := xml.Unmarshal([]byte(funcCallXml), &functionCalls)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to unmarshal function_calls: %v", err)
|
return "", fmt.Errorf("failed to unmarshal function_calls: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute function calls
|
|
||||||
toolCall := model.Message{
|
toolCall := model.Message{
|
||||||
Role: model.MessageRoleToolCall,
|
Role: model.MessageRoleToolCall,
|
||||||
// xml stripped from content
|
// function call xml stripped from content for model interop
|
||||||
Content: content[:start],
|
Content: strings.TrimSpace(content[:start]),
|
||||||
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
|
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -305,33 +325,36 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
toolReply := model.Message{
|
toolResult := model.Message{
|
||||||
Role: model.MessageRoleToolResult,
|
Role: model.MessageRoleToolResult,
|
||||||
ToolResults: toolResults,
|
ToolResults: toolResults,
|
||||||
}
|
}
|
||||||
|
|
||||||
if callback != nil {
|
if callback != nil {
|
||||||
callback(toolCall)
|
callback(toolCall)
|
||||||
callback(toolReply)
|
callback(toolResult)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Recurse into CreateChatCompletionStream with the tool call replies
|
if continuation {
|
||||||
// added to the original messages
|
messages[len(messages)-1] = toolCall
|
||||||
messages = append(append(messages, toolCall), toolReply)
|
} else {
|
||||||
|
messages = append(messages, toolCall)
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = append(messages, toolResult)
|
||||||
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
|
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "message_stop":
|
case "message_stop":
|
||||||
// return the completed message
|
// return the completed message
|
||||||
|
content := sb.String()
|
||||||
if callback != nil {
|
if callback != nil {
|
||||||
if !isToolCall {
|
|
||||||
callback(model.Message{
|
callback(model.Message{
|
||||||
Role: model.MessageRoleAssistant,
|
Role: model.MessageRoleAssistant,
|
||||||
Content: sb.String(),
|
Content: content,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
return content, nil
|
||||||
return sb.String(), nil
|
|
||||||
case "error":
|
case "error":
|
||||||
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
|
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -137,7 +137,15 @@ func handleToolCalls(
|
||||||
params model.RequestParameters,
|
params model.RequestParameters,
|
||||||
content string,
|
content string,
|
||||||
toolCalls []openai.ToolCall,
|
toolCalls []openai.ToolCall,
|
||||||
|
callback provider.ReplyCallback,
|
||||||
|
messages []model.Message,
|
||||||
) ([]model.Message, error) {
|
) ([]model.Message, error) {
|
||||||
|
lastMessage := messages[len(messages)-1]
|
||||||
|
continuation := false
|
||||||
|
if lastMessage.Role.IsAssistant() {
|
||||||
|
continuation = true
|
||||||
|
}
|
||||||
|
|
||||||
toolCall := model.Message{
|
toolCall := model.Message{
|
||||||
Role: model.MessageRoleToolCall,
|
Role: model.MessageRoleToolCall,
|
||||||
Content: content,
|
Content: content,
|
||||||
|
@ -154,7 +162,19 @@ func handleToolCalls(
|
||||||
ToolResults: toolResults,
|
ToolResults: toolResults,
|
||||||
}
|
}
|
||||||
|
|
||||||
return []model.Message{toolCall, toolResult}, nil
|
if callback != nil {
|
||||||
|
callback(toolCall)
|
||||||
|
callback(toolResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
if continuation {
|
||||||
|
messages[len(messages)-1] = toolCall
|
||||||
|
} else {
|
||||||
|
messages = append(messages, toolCall)
|
||||||
|
}
|
||||||
|
messages = append(messages, toolResult)
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *OpenAIClient) CreateChatCompletion(
|
func (c *OpenAIClient) CreateChatCompletion(
|
||||||
|
@ -163,6 +183,10 @@ func (c *OpenAIClient) CreateChatCompletion(
|
||||||
messages []model.Message,
|
messages []model.Message,
|
||||||
callback provider.ReplyCallback,
|
callback provider.ReplyCallback,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
|
if len(messages) == 0 {
|
||||||
|
return "", fmt.Errorf("Can't create completion from no messages")
|
||||||
|
}
|
||||||
|
|
||||||
client := openai.NewClient(c.APIKey)
|
client := openai.NewClient(c.APIKey)
|
||||||
req := createChatCompletionRequest(c, params, messages)
|
req := createChatCompletionRequest(c, params, messages)
|
||||||
resp, err := client.CreateChatCompletion(ctx, req)
|
resp, err := client.CreateChatCompletion(ctx, req)
|
||||||
|
@ -172,32 +196,33 @@ func (c *OpenAIClient) CreateChatCompletion(
|
||||||
|
|
||||||
choice := resp.Choices[0]
|
choice := resp.Choices[0]
|
||||||
|
|
||||||
toolCalls := choice.Message.ToolCalls
|
var content string
|
||||||
if len(toolCalls) > 0 {
|
lastMessage := messages[len(messages)-1]
|
||||||
results, err := handleToolCalls(params, choice.Message.Content, toolCalls)
|
if lastMessage.Role.IsAssistant() {
|
||||||
if err != nil {
|
content = lastMessage.Content + choice.Message.Content
|
||||||
return "", err
|
} else {
|
||||||
}
|
content = choice.Message.Content
|
||||||
if callback != nil {
|
}
|
||||||
for _, result := range results {
|
|
||||||
callback(result)
|
toolCalls := choice.Message.ToolCalls
|
||||||
}
|
if len(toolCalls) > 0 {
|
||||||
|
messages, err := handleToolCalls(params, content, toolCalls, callback, messages)
|
||||||
|
if err != nil {
|
||||||
|
return content, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Recurse into CreateChatCompletion with the tool call replies
|
|
||||||
messages = append(messages, results...)
|
|
||||||
return c.CreateChatCompletion(ctx, params, messages, callback)
|
return c.CreateChatCompletion(ctx, params, messages, callback)
|
||||||
}
|
}
|
||||||
|
|
||||||
if callback != nil {
|
if callback != nil {
|
||||||
callback(model.Message{
|
callback(model.Message{
|
||||||
Role: model.MessageRoleAssistant,
|
Role: model.MessageRoleAssistant,
|
||||||
Content: choice.Message.Content,
|
Content: content,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return the user-facing message.
|
// Return the user-facing message.
|
||||||
return choice.Message.Content, nil
|
return content, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *OpenAIClient) CreateChatCompletionStream(
|
func (c *OpenAIClient) CreateChatCompletionStream(
|
||||||
|
@ -207,6 +232,10 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
||||||
callback provider.ReplyCallback,
|
callback provider.ReplyCallback,
|
||||||
output chan<- string,
|
output chan<- string,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
|
if len(messages) == 0 {
|
||||||
|
return "", fmt.Errorf("Can't create completion from no messages")
|
||||||
|
}
|
||||||
|
|
||||||
client := openai.NewClient(c.APIKey)
|
client := openai.NewClient(c.APIKey)
|
||||||
req := createChatCompletionRequest(c, params, messages)
|
req := createChatCompletionRequest(c, params, messages)
|
||||||
|
|
||||||
|
@ -219,6 +248,11 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
||||||
content := strings.Builder{}
|
content := strings.Builder{}
|
||||||
toolCalls := []openai.ToolCall{}
|
toolCalls := []openai.ToolCall{}
|
||||||
|
|
||||||
|
lastMessage := messages[len(messages)-1]
|
||||||
|
if lastMessage.Role.IsAssistant() {
|
||||||
|
content.WriteString(lastMessage.Content)
|
||||||
|
}
|
||||||
|
|
||||||
// Iterate stream segments
|
// Iterate stream segments
|
||||||
for {
|
for {
|
||||||
response, e := stream.Recv()
|
response, e := stream.Recv()
|
||||||
|
@ -251,19 +285,12 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(toolCalls) > 0 {
|
if len(toolCalls) > 0 {
|
||||||
results, err := handleToolCalls(params, content.String(), toolCalls)
|
messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return content.String(), err
|
return content.String(), err
|
||||||
}
|
}
|
||||||
|
|
||||||
if callback != nil {
|
|
||||||
for _, result := range results {
|
|
||||||
callback(result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Recurse into CreateChatCompletionStream with the tool call replies
|
// Recurse into CreateChatCompletionStream with the tool call replies
|
||||||
messages = append(messages, results...)
|
|
||||||
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
|
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
|
||||||
} else {
|
} else {
|
||||||
if callback != nil {
|
if callback != nil {
|
||||||
|
|
|
@ -152,6 +152,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
case "ctrl+c":
|
case "ctrl+c":
|
||||||
if m.waitingForReply {
|
if m.waitingForReply {
|
||||||
m.stopSignal <- ""
|
m.stopSignal <- ""
|
||||||
|
return m, nil
|
||||||
} else {
|
} else {
|
||||||
return m, tea.Quit
|
return m, tea.Quit
|
||||||
}
|
}
|
||||||
|
@ -192,7 +193,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
case msgResponseChunk:
|
case msgResponseChunk:
|
||||||
chunk := string(msg)
|
chunk := string(msg)
|
||||||
last := len(m.messages) - 1
|
last := len(m.messages) - 1
|
||||||
if last >= 0 && m.messages[last].Role == models.MessageRoleAssistant {
|
if last >= 0 && m.messages[last].Role.IsAssistant() {
|
||||||
m.setMessageContents(last, m.messages[last].Content+chunk)
|
m.setMessageContents(last, m.messages[last].Content+chunk)
|
||||||
} else {
|
} else {
|
||||||
m.addMessage(models.Message{
|
m.addMessage(models.Message{
|
||||||
|
@ -205,17 +206,16 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
case msgAssistantReply:
|
case msgAssistantReply:
|
||||||
// the last reply that was being worked on is finished
|
// the last reply that was being worked on is finished
|
||||||
reply := models.Message(msg)
|
reply := models.Message(msg)
|
||||||
|
reply.Content = strings.TrimSpace(reply.Content)
|
||||||
|
|
||||||
last := len(m.messages) - 1
|
last := len(m.messages) - 1
|
||||||
if last < 0 {
|
if last < 0 {
|
||||||
panic("Unexpected empty messages handling msgReply")
|
panic("Unexpected empty messages handling msgAssistantReply")
|
||||||
}
|
|
||||||
m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content))
|
|
||||||
if m.messages[last].Role == models.MessageRoleAssistant {
|
|
||||||
// the last message was an assistant message, so this is a continuation
|
|
||||||
if reply.Role == models.MessageRoleToolCall {
|
|
||||||
// update last message rrole to tool call
|
|
||||||
m.messages[last].Role = models.MessageRoleToolCall
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if reply.Role.IsAssistant() && m.messages[last].Role.IsAssistant() {
|
||||||
|
// this was a continuation, so replace the previous message with the completed reply
|
||||||
|
m.setMessage(last, reply)
|
||||||
} else {
|
} else {
|
||||||
m.addMessage(reply)
|
m.addMessage(reply)
|
||||||
}
|
}
|
||||||
|
@ -496,9 +496,8 @@ func (m *model) handleMessagesKey(msg tea.KeyMsg) tea.Cmd {
|
||||||
scrollIntoView(&m.content, offset, 0.1)
|
scrollIntoView(&m.content, offset, 0.1)
|
||||||
}
|
}
|
||||||
case "ctrl+r":
|
case "ctrl+r":
|
||||||
// resubmit the conversation with all messages up until and including
|
// resubmit the conversation with all messages up until and including the selected message
|
||||||
// the selected message
|
if m.waitingForReply || len(m.messages) == 0 {
|
||||||
if len(m.messages) == 0 {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
m.messages = m.messages[:m.selectedMessage+1]
|
m.messages = m.messages[:m.selectedMessage+1]
|
||||||
|
@ -543,13 +542,12 @@ func (m *model) handleInputKey(msg tea.KeyMsg) tea.Cmd {
|
||||||
return wrapError(err)
|
return wrapError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensure all messages up to the one we're about to add are
|
// ensure all messages up to the one we're about to add are persisted
|
||||||
// persistent
|
|
||||||
cmd := m.persistConversation()
|
cmd := m.persistConversation()
|
||||||
if cmd != nil {
|
if cmd != nil {
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
// persist our new message, returning with any possible errors
|
|
||||||
savedReply, err := m.ctx.Store.AddReply(m.conversation, reply)
|
savedReply, err := m.ctx.Store.AddReply(m.conversation, reply)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return wrapError(err)
|
return wrapError(err)
|
||||||
|
|
Loading…
Reference in New Issue