When the last message in the passed messages slice is an assistant message, treat it as a partial message that is being continued, and include its content in the newly created reply. Update the TUI code to handle the new behavior.
372 lines · 9.8 KiB · Go
package anthropic
|
|
|
|
import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"encoding/xml"
	"fmt"
	"io"
	"net/http"
	"strings"

	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
)
|
|
|
|
// AnthropicClient performs chat completions against the Anthropic
// messages API (https://api.anthropic.com/v1/messages).
type AnthropicClient struct {
	// APIKey is sent as the x-api-key header on every request.
	APIKey string
}
|
|
|
|
// Message is a single conversation turn in the request payload sent to
// the Anthropic messages API.
type Message struct {
	// Role is the speaker of this turn; buildRequest emits "assistant"
	// or "user".
	Role string `json:"role"`
	// Content is the plain-text body of the turn (tool calls/results are
	// serialized into it as XML by buildRequest).
	Content string `json:"content"`
}
|
|
|
|
// Request is the JSON body of an Anthropic messages API call.
type Request struct {
	Model    string    `json:"model"`
	Messages []Message `json:"messages"`
	// System is the system prompt; Anthropic takes it as a top-level
	// field rather than as a message with a "system" role.
	System        string   `json:"system,omitempty"`
	MaxTokens     int      `json:"max_tokens,omitempty"`
	StopSequences []string `json:"stop_sequences,omitempty"`
	// Stream selects SSE streaming delivery when true.
	Stream      bool    `json:"stream,omitempty"`
	Temperature float32 `json:"temperature,omitempty"`
	//TopP float32 `json:"top_p,omitempty"`
	//TopK float32 `json:"top_k,omitempty"`
}
|
|
|
|
// OriginalContent is one content block of an Anthropic API response.
type OriginalContent struct {
	// Type discriminates the block kind; only "text" is handled here.
	Type string `json:"type"`
	// Text is the block's text when Type == "text".
	Text string `json:"text"`
}
|
|
|
|
// Response is the JSON body of a non-streaming Anthropic messages API
// reply.
type Response struct {
	Id      string            `json:"id"`
	Type    string            `json:"type"`
	Role    string            `json:"role"`
	Content []OriginalContent `json:"content"`
	// StopReason reports why generation ended (e.g. "stop_sequence").
	StopReason string `json:"stop_reason"`
	// StopSequence is the matched sequence when StopReason is
	// "stop_sequence".
	StopSequence string `json:"stop_sequence"`
}
|
|
|
|
// FUNCTION_STOP_SEQUENCE halts generation at the closing tag of a
// model-emitted XML function-call block, so the partial output can be
// completed locally, parsed, and the requested tools executed.
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
|
|
|
|
func buildRequest(params model.RequestParameters, messages []model.Message) Request {
|
|
requestBody := Request{
|
|
Model: params.Model,
|
|
Messages: make([]Message, len(messages)),
|
|
System: params.SystemPrompt,
|
|
MaxTokens: params.MaxTokens,
|
|
Temperature: params.Temperature,
|
|
Stream: false,
|
|
|
|
StopSequences: []string{
|
|
FUNCTION_STOP_SEQUENCE,
|
|
"\n\nHuman:",
|
|
},
|
|
}
|
|
|
|
startIdx := 0
|
|
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
|
|
requestBody.System = messages[0].Content
|
|
requestBody.Messages = requestBody.Messages[1:]
|
|
startIdx = 1
|
|
}
|
|
|
|
if len(params.ToolBag) > 0 {
|
|
if len(requestBody.System) > 0 {
|
|
// add a divider between existing system prompt and tools
|
|
requestBody.System += "\n\n---\n\n"
|
|
}
|
|
requestBody.System += buildToolsSystemPrompt(params.ToolBag)
|
|
}
|
|
|
|
for i, msg := range messages[startIdx:] {
|
|
message := &requestBody.Messages[i]
|
|
|
|
switch msg.Role {
|
|
case model.MessageRoleToolCall:
|
|
message.Role = "assistant"
|
|
if msg.Content != "" {
|
|
message.Content = msg.Content
|
|
}
|
|
xmlFuncCalls := convertToolCallsToXMLFunctionCalls(msg.ToolCalls)
|
|
xmlString, err := xmlFuncCalls.XMLString()
|
|
if err != nil {
|
|
panic("Could not serialize []ToolCall to XMLFunctionCall")
|
|
}
|
|
if len(message.Content) > 0 {
|
|
message.Content += fmt.Sprintf("\n\n%s", xmlString)
|
|
} else {
|
|
message.Content = xmlString
|
|
}
|
|
case model.MessageRoleToolResult:
|
|
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
|
|
xmlString, err := xmlFuncResults.XMLString()
|
|
if err != nil {
|
|
panic("Could not serialize []ToolResult to XMLFunctionResults")
|
|
}
|
|
message.Role = "user"
|
|
message.Content = xmlString
|
|
default:
|
|
message.Role = string(msg.Role)
|
|
message.Content = msg.Content
|
|
}
|
|
}
|
|
return requestBody
|
|
}
|
|
|
|
func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) {
|
|
url := "https://api.anthropic.com/v1/messages"
|
|
|
|
jsonBody, err := json.Marshal(r)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request body: %v", err)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonBody))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create HTTP request: %v", err)
|
|
}
|
|
|
|
req.Header.Set("x-api-key", c.APIKey)
|
|
req.Header.Set("anthropic-version", "2023-06-01")
|
|
req.Header.Set("content-type", "application/json")
|
|
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to send HTTP request: %v", err)
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
func (c *AnthropicClient) CreateChatCompletion(
|
|
ctx context.Context,
|
|
params model.RequestParameters,
|
|
messages []model.Message,
|
|
callback provider.ReplyCallback,
|
|
) (string, error) {
|
|
if len(messages) == 0 {
|
|
return "", fmt.Errorf("Can't create completion from no messages")
|
|
}
|
|
|
|
request := buildRequest(params, messages)
|
|
|
|
resp, err := sendRequest(ctx, c, request)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
var response Response
|
|
err = json.NewDecoder(resp.Body).Decode(&response)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to decode response: %v", err)
|
|
}
|
|
|
|
sb := strings.Builder{}
|
|
|
|
lastMessage := messages[len(messages)-1]
|
|
if lastMessage.Role.IsAssistant() {
|
|
// this is a continuation of a previous assistant reply, so we'll
|
|
// include its contents in the final result
|
|
sb.WriteString(lastMessage.Content)
|
|
}
|
|
|
|
for _, content := range response.Content {
|
|
var reply model.Message
|
|
switch content.Type {
|
|
case "text":
|
|
reply = model.Message{
|
|
Role: model.MessageRoleAssistant,
|
|
Content: content.Text,
|
|
}
|
|
sb.WriteString(reply.Content)
|
|
default:
|
|
return "", fmt.Errorf("unsupported message type: %s", content.Type)
|
|
}
|
|
if callback != nil {
|
|
callback(reply)
|
|
}
|
|
}
|
|
|
|
return sb.String(), nil
|
|
}
|
|
|
|
// CreateChatCompletionStream requests a streaming assistant reply and
// forwards text deltas to output as they arrive, returning the full
// accumulated reply text when the stream completes.
//
// If the last message in messages is an assistant message, it is
// treated as a partial reply being continued: its content seeds the
// accumulator (and the return value), and the streamed deltas extend it.
//
// When the model is stopped by FUNCTION_STOP_SEQUENCE, the accumulated
// <function_calls> XML is parsed, the tool calls are executed, the
// call/result messages are appended to the conversation (reported via
// callback), and the method recurses to let the model continue with the
// results. On normal completion ("message_stop"), callback (if non-nil)
// receives the single final assistant message.
func (c *AnthropicClient) CreateChatCompletionStream(
	ctx context.Context,
	params model.RequestParameters,
	messages []model.Message,
	callback provider.ReplyCallback,
	output chan<- string,
) (string, error) {
	if len(messages) == 0 {
		return "", fmt.Errorf("Can't create completion from no messages")
	}

	request := buildRequest(params, messages)
	request.Stream = true

	resp, err := sendRequest(ctx, c, request)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	// accumulates the full reply text across all streamed deltas
	sb := strings.Builder{}

	lastMessage := messages[len(messages)-1]
	continuation := false
	if messages[len(messages)-1].Role.IsAssistant() {
		// this is a continuation of a previous assistant reply, so we'll
		// include its contents in the final result
		sb.WriteString(lastMessage.Content)
		continuation = true
	}

	// read the SSE stream line by line
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := scanner.Text()
		line = strings.TrimSpace(line)

		if len(line) == 0 {
			continue
		}

		if line[0] == '{' {
			// a bare JSON object (not an SSE "data:" line) is an
			// out-of-band error from the API
			var event map[string]interface{}
			err := json.Unmarshal([]byte(line), &event)
			if err != nil {
				return "", fmt.Errorf("failed to unmarshal event data '%s': %v", line, err)
			}
			eventType, ok := event["type"].(string)
			if !ok {
				return "", fmt.Errorf("invalid event: %s", line)
			}
			switch eventType {
			case "error":
				return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
			default:
				return sb.String(), fmt.Errorf("unknown event type: %s", eventType)
			}
		} else if strings.HasPrefix(line, "data:") {
			data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
			var event map[string]interface{}
			err := json.Unmarshal([]byte(data), &event)
			if err != nil {
				return "", fmt.Errorf("failed to unmarshal event data: %v", err)
			}

			eventType, ok := event["type"].(string)
			if !ok {
				return "", fmt.Errorf("invalid event type")
			}

			switch eventType {
			case "message_start":
				// noop
			case "ping":
				// write an empty string to signal start of text
				output <- ""
			case "content_block_start":
				// ignore?
			case "content_block_delta":
				// incremental text: append to the accumulator and
				// forward to the consumer
				delta, ok := event["delta"].(map[string]interface{})
				if !ok {
					return "", fmt.Errorf("invalid content block delta")
				}
				text, ok := delta["text"].(string)
				if !ok {
					return "", fmt.Errorf("invalid text delta")
				}
				sb.WriteString(text)
				output <- text
			case "content_block_stop":
				// ignore?
			case "message_delta":
				delta, ok := event["delta"].(map[string]interface{})
				if !ok {
					return "", fmt.Errorf("invalid message delta")
				}
				stopReason, ok := delta["stop_reason"].(string)
				if ok && stopReason == "stop_sequence" {
					stopSequence, ok := delta["stop_sequence"].(string)
					if ok && stopSequence == FUNCTION_STOP_SEQUENCE {
						// the model was cut off mid function-call block;
						// content holds everything generated so far
						content := sb.String()

						start := strings.Index(content, "<function_calls>")
						if start == -1 {
							return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
						}

						// re-append the stop sequence the API stripped,
						// so the visible output/XML is well-formed
						sb.WriteString(FUNCTION_STOP_SEQUENCE)
						output <- FUNCTION_STOP_SEQUENCE

						funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE

						var functionCalls XMLFunctionCalls
						err := xml.Unmarshal([]byte(funcCallXml), &functionCalls)
						if err != nil {
							return "", fmt.Errorf("failed to unmarshal function_calls: %v", err)
						}

						toolCall := model.Message{
							Role: model.MessageRoleToolCall,
							// function call xml stripped from content for model interop
							Content:   strings.TrimSpace(content[:start]),
							ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
						}

						toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
						if err != nil {
							return "", err
						}

						toolResult := model.Message{
							Role:        model.MessageRoleToolResult,
							ToolResults: toolResults,
						}

						if callback != nil {
							callback(toolCall)
							callback(toolResult)
						}

						if continuation {
							// the partial assistant message we were
							// continuing is superseded by the tool call
							messages[len(messages)-1] = toolCall
						} else {
							messages = append(messages, toolCall)
						}

						// recurse with the tool results appended so the
						// model can continue the reply
						messages = append(messages, toolResult)
						return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
					}
				}
			case "message_stop":
				// return the completed message
				content := sb.String()
				if callback != nil {
					callback(model.Message{
						Role:    model.MessageRoleAssistant,
						Content: content,
					})
				}
				return content, nil
			case "error":
				return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
			default:
				// unknown event types are logged but not fatal
				fmt.Printf("\nUnrecognized event: %s\n", data)
			}
		}
	}

	if err := scanner.Err(); err != nil {
		return "", fmt.Errorf("failed to read response body: %v", err)
	}

	// the stream ended without a message_stop event
	return "", fmt.Errorf("unexpected end of stream")
}
|