2024-02-21 21:55:38 -07:00
|
|
|
package anthropic
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bufio"
|
|
|
|
"bytes"
|
2024-03-12 12:24:05 -06:00
|
|
|
"context"
|
2024-02-21 21:55:38 -07:00
|
|
|
"encoding/json"
|
|
|
|
"encoding/xml"
|
|
|
|
"fmt"
|
|
|
|
"net/http"
|
|
|
|
"strings"
|
|
|
|
|
2024-06-09 10:42:53 -06:00
|
|
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
2024-02-21 21:55:38 -07:00
|
|
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
2024-03-12 12:24:05 -06:00
|
|
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
2024-02-21 21:55:38 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
func buildRequest(params model.RequestParameters, messages []model.Message) Request {
|
|
|
|
requestBody := Request{
|
|
|
|
Model: params.Model,
|
|
|
|
Messages: make([]Message, len(messages)),
|
|
|
|
MaxTokens: params.MaxTokens,
|
|
|
|
Temperature: params.Temperature,
|
|
|
|
Stream: false,
|
|
|
|
|
|
|
|
StopSequences: []string{
|
|
|
|
FUNCTION_STOP_SEQUENCE,
|
|
|
|
"\n\nHuman:",
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
startIdx := 0
|
2024-03-12 14:54:02 -06:00
|
|
|
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
|
2024-02-21 21:55:38 -07:00
|
|
|
requestBody.System = messages[0].Content
|
2024-03-17 15:59:42 -06:00
|
|
|
requestBody.Messages = requestBody.Messages[1:]
|
2024-02-21 21:55:38 -07:00
|
|
|
startIdx = 1
|
|
|
|
}
|
|
|
|
|
|
|
|
if len(params.ToolBag) > 0 {
|
|
|
|
if len(requestBody.System) > 0 {
|
|
|
|
// add a divider between existing system prompt and tools
|
|
|
|
requestBody.System += "\n\n---\n\n"
|
|
|
|
}
|
|
|
|
requestBody.System += buildToolsSystemPrompt(params.ToolBag)
|
|
|
|
}
|
|
|
|
|
|
|
|
for i, msg := range messages[startIdx:] {
|
|
|
|
message := &requestBody.Messages[i]
|
|
|
|
|
|
|
|
switch msg.Role {
|
|
|
|
case model.MessageRoleToolCall:
|
|
|
|
message.Role = "assistant"
|
2024-03-12 14:54:02 -06:00
|
|
|
if msg.Content != "" {
|
|
|
|
message.Content = msg.Content
|
|
|
|
}
|
|
|
|
xmlFuncCalls := convertToolCallsToXMLFunctionCalls(msg.ToolCalls)
|
|
|
|
xmlString, err := xmlFuncCalls.XMLString()
|
|
|
|
if err != nil {
|
|
|
|
panic("Could not serialize []ToolCall to XMLFunctionCall")
|
|
|
|
}
|
2024-03-17 12:26:32 -06:00
|
|
|
if len(message.Content) > 0 {
|
|
|
|
message.Content += fmt.Sprintf("\n\n%s", xmlString)
|
|
|
|
} else {
|
|
|
|
message.Content = xmlString
|
|
|
|
}
|
2024-02-21 21:55:38 -07:00
|
|
|
case model.MessageRoleToolResult:
|
|
|
|
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
|
|
|
|
xmlString, err := xmlFuncResults.XMLString()
|
|
|
|
if err != nil {
|
|
|
|
panic("Could not serialize []ToolResult to XMLFunctionResults")
|
|
|
|
}
|
|
|
|
message.Role = "user"
|
2024-03-12 14:54:02 -06:00
|
|
|
message.Content = xmlString
|
2024-02-21 21:55:38 -07:00
|
|
|
default:
|
|
|
|
message.Role = string(msg.Role)
|
2024-03-12 14:54:02 -06:00
|
|
|
message.Content = msg.Content
|
2024-02-21 21:55:38 -07:00
|
|
|
}
|
|
|
|
}
|
|
|
|
return requestBody
|
|
|
|
}
|
|
|
|
|
2024-03-12 12:24:05 -06:00
|
|
|
func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) {
|
2024-02-21 21:55:38 -07:00
|
|
|
jsonBody, err := json.Marshal(r)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to marshal request body: %v", err)
|
|
|
|
}
|
|
|
|
|
2024-05-05 02:08:17 -06:00
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/messages", bytes.NewBuffer(jsonBody))
|
2024-02-21 21:55:38 -07:00
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to create HTTP request: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
req.Header.Set("x-api-key", c.APIKey)
|
|
|
|
req.Header.Set("anthropic-version", "2023-06-01")
|
|
|
|
req.Header.Set("content-type", "application/json")
|
|
|
|
|
|
|
|
client := &http.Client{}
|
|
|
|
resp, err := client.Do(req)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to send HTTP request: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return resp, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *AnthropicClient) CreateChatCompletion(
|
2024-03-12 12:24:05 -06:00
|
|
|
ctx context.Context,
|
2024-02-21 21:55:38 -07:00
|
|
|
params model.RequestParameters,
|
|
|
|
messages []model.Message,
|
2024-06-09 10:42:53 -06:00
|
|
|
callback api.ReplyCallback,
|
2024-02-21 21:55:38 -07:00
|
|
|
) (string, error) {
|
2024-03-22 11:51:01 -06:00
|
|
|
if len(messages) == 0 {
|
|
|
|
return "", fmt.Errorf("Can't create completion from no messages")
|
|
|
|
}
|
|
|
|
|
2024-02-21 21:55:38 -07:00
|
|
|
request := buildRequest(params, messages)
|
|
|
|
|
2024-03-12 12:24:05 -06:00
|
|
|
resp, err := sendRequest(ctx, c, request)
|
2024-02-21 21:55:38 -07:00
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
|
|
|
|
|
|
var response Response
|
|
|
|
err = json.NewDecoder(resp.Body).Decode(&response)
|
|
|
|
if err != nil {
|
|
|
|
return "", fmt.Errorf("failed to decode response: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
sb := strings.Builder{}
|
2024-03-22 11:51:01 -06:00
|
|
|
|
|
|
|
lastMessage := messages[len(messages)-1]
|
|
|
|
if lastMessage.Role.IsAssistant() {
|
|
|
|
// this is a continuation of a previous assistant reply, so we'll
|
|
|
|
// include its contents in the final result
|
|
|
|
sb.WriteString(lastMessage.Content)
|
|
|
|
}
|
|
|
|
|
2024-03-12 14:54:02 -06:00
|
|
|
for _, content := range response.Content {
|
2024-02-21 21:55:38 -07:00
|
|
|
var reply model.Message
|
|
|
|
switch content.Type {
|
|
|
|
case "text":
|
|
|
|
reply = model.Message{
|
|
|
|
Role: model.MessageRoleAssistant,
|
|
|
|
Content: content.Text,
|
|
|
|
}
|
|
|
|
sb.WriteString(reply.Content)
|
|
|
|
default:
|
|
|
|
return "", fmt.Errorf("unsupported message type: %s", content.Type)
|
|
|
|
}
|
2024-03-12 14:36:24 -06:00
|
|
|
if callback != nil {
|
|
|
|
callback(reply)
|
|
|
|
}
|
2024-02-21 21:55:38 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
return sb.String(), nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *AnthropicClient) CreateChatCompletionStream(
|
2024-03-12 12:24:05 -06:00
|
|
|
ctx context.Context,
|
2024-02-21 21:55:38 -07:00
|
|
|
params model.RequestParameters,
|
|
|
|
messages []model.Message,
|
2024-06-09 10:42:53 -06:00
|
|
|
callback api.ReplyCallback,
|
|
|
|
output chan<- api.Chunk,
|
2024-02-21 21:55:38 -07:00
|
|
|
) (string, error) {
|
2024-03-22 11:51:01 -06:00
|
|
|
if len(messages) == 0 {
|
|
|
|
return "", fmt.Errorf("Can't create completion from no messages")
|
|
|
|
}
|
|
|
|
|
2024-02-21 21:55:38 -07:00
|
|
|
request := buildRequest(params, messages)
|
|
|
|
request.Stream = true
|
|
|
|
|
2024-03-12 12:24:05 -06:00
|
|
|
resp, err := sendRequest(ctx, c, request)
|
2024-02-21 21:55:38 -07:00
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
|
|
|
|
|
|
sb := strings.Builder{}
|
|
|
|
|
2024-03-22 11:51:01 -06:00
|
|
|
lastMessage := messages[len(messages)-1]
|
|
|
|
continuation := false
|
|
|
|
if messages[len(messages)-1].Role.IsAssistant() {
|
|
|
|
// this is a continuation of a previous assistant reply, so we'll
|
|
|
|
// include its contents in the final result
|
|
|
|
sb.WriteString(lastMessage.Content)
|
|
|
|
continuation = true
|
|
|
|
}
|
2024-03-16 19:07:52 -06:00
|
|
|
|
2024-03-22 11:51:01 -06:00
|
|
|
scanner := bufio.NewScanner(resp.Body)
|
2024-02-21 21:55:38 -07:00
|
|
|
for scanner.Scan() {
|
|
|
|
line := scanner.Text()
|
|
|
|
line = strings.TrimSpace(line)
|
|
|
|
|
|
|
|
if len(line) == 0 {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
if line[0] == '{' {
|
|
|
|
var event map[string]interface{}
|
|
|
|
err := json.Unmarshal([]byte(line), &event)
|
|
|
|
if err != nil {
|
|
|
|
return "", fmt.Errorf("failed to unmarshal event data '%s': %v", line, err)
|
|
|
|
}
|
|
|
|
eventType, ok := event["type"].(string)
|
|
|
|
if !ok {
|
|
|
|
return "", fmt.Errorf("invalid event: %s", line)
|
|
|
|
}
|
|
|
|
switch eventType {
|
|
|
|
case "error":
|
|
|
|
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
|
|
|
|
default:
|
|
|
|
return sb.String(), fmt.Errorf("unknown event type: %s", eventType)
|
|
|
|
}
|
|
|
|
} else if strings.HasPrefix(line, "data:") {
|
|
|
|
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
|
|
|
var event map[string]interface{}
|
|
|
|
err := json.Unmarshal([]byte(data), &event)
|
|
|
|
if err != nil {
|
|
|
|
return "", fmt.Errorf("failed to unmarshal event data: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
eventType, ok := event["type"].(string)
|
|
|
|
if !ok {
|
|
|
|
return "", fmt.Errorf("invalid event type")
|
|
|
|
}
|
|
|
|
|
|
|
|
switch eventType {
|
|
|
|
case "message_start":
|
|
|
|
// noop
|
|
|
|
case "ping":
|
2024-05-30 12:52:23 -06:00
|
|
|
// signals start of text - currently ignoring
|
2024-02-21 21:55:38 -07:00
|
|
|
case "content_block_start":
|
|
|
|
// ignore?
|
|
|
|
case "content_block_delta":
|
|
|
|
delta, ok := event["delta"].(map[string]interface{})
|
|
|
|
if !ok {
|
|
|
|
return "", fmt.Errorf("invalid content block delta")
|
|
|
|
}
|
|
|
|
text, ok := delta["text"].(string)
|
|
|
|
if !ok {
|
|
|
|
return "", fmt.Errorf("invalid text delta")
|
|
|
|
}
|
|
|
|
sb.WriteString(text)
|
2024-06-09 10:42:53 -06:00
|
|
|
output <- api.Chunk{
|
2024-06-08 17:37:58 -06:00
|
|
|
Content: text,
|
|
|
|
}
|
2024-02-21 21:55:38 -07:00
|
|
|
case "content_block_stop":
|
|
|
|
// ignore?
|
|
|
|
case "message_delta":
|
|
|
|
delta, ok := event["delta"].(map[string]interface{})
|
|
|
|
if !ok {
|
|
|
|
return "", fmt.Errorf("invalid message delta")
|
|
|
|
}
|
|
|
|
stopReason, ok := delta["stop_reason"].(string)
|
|
|
|
if ok && stopReason == "stop_sequence" {
|
|
|
|
stopSequence, ok := delta["stop_sequence"].(string)
|
|
|
|
if ok && stopSequence == FUNCTION_STOP_SEQUENCE {
|
|
|
|
content := sb.String()
|
|
|
|
|
|
|
|
start := strings.Index(content, "<function_calls>")
|
|
|
|
if start == -1 {
|
|
|
|
return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
|
|
|
|
}
|
|
|
|
|
|
|
|
sb.WriteString(FUNCTION_STOP_SEQUENCE)
|
2024-06-09 10:42:53 -06:00
|
|
|
output <- api.Chunk{
|
2024-06-08 17:37:58 -06:00
|
|
|
Content: FUNCTION_STOP_SEQUENCE,
|
|
|
|
}
|
2024-02-21 21:55:38 -07:00
|
|
|
|
2024-03-22 11:51:01 -06:00
|
|
|
funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE
|
|
|
|
|
2024-02-21 21:55:38 -07:00
|
|
|
var functionCalls XMLFunctionCalls
|
2024-03-22 11:51:01 -06:00
|
|
|
err := xml.Unmarshal([]byte(funcCallXml), &functionCalls)
|
2024-02-21 21:55:38 -07:00
|
|
|
if err != nil {
|
|
|
|
return "", fmt.Errorf("failed to unmarshal function_calls: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
toolCall := model.Message{
|
2024-03-12 14:54:02 -06:00
|
|
|
Role: model.MessageRoleToolCall,
|
2024-03-22 11:51:01 -06:00
|
|
|
// function call xml stripped from content for model interop
|
|
|
|
Content: strings.TrimSpace(content[:start]),
|
2024-02-21 21:55:38 -07:00
|
|
|
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
|
|
|
|
}
|
|
|
|
|
|
|
|
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
2024-03-22 11:51:01 -06:00
|
|
|
toolResult := model.Message{
|
2024-02-21 21:55:38 -07:00
|
|
|
Role: model.MessageRoleToolResult,
|
|
|
|
ToolResults: toolResults,
|
|
|
|
}
|
|
|
|
|
2024-03-12 14:36:24 -06:00
|
|
|
if callback != nil {
|
|
|
|
callback(toolCall)
|
2024-03-22 11:51:01 -06:00
|
|
|
callback(toolResult)
|
|
|
|
}
|
|
|
|
|
|
|
|
if continuation {
|
|
|
|
messages[len(messages)-1] = toolCall
|
|
|
|
} else {
|
|
|
|
messages = append(messages, toolCall)
|
2024-02-21 21:55:38 -07:00
|
|
|
}
|
|
|
|
|
2024-03-22 11:51:01 -06:00
|
|
|
messages = append(messages, toolResult)
|
2024-03-12 14:36:24 -06:00
|
|
|
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
|
2024-02-21 21:55:38 -07:00
|
|
|
}
|
|
|
|
}
|
|
|
|
case "message_stop":
|
|
|
|
// return the completed message
|
2024-03-22 11:51:01 -06:00
|
|
|
content := sb.String()
|
2024-03-12 14:36:24 -06:00
|
|
|
if callback != nil {
|
2024-03-22 11:51:01 -06:00
|
|
|
callback(model.Message{
|
|
|
|
Role: model.MessageRoleAssistant,
|
|
|
|
Content: content,
|
|
|
|
})
|
2024-02-21 21:55:38 -07:00
|
|
|
}
|
2024-03-22 11:51:01 -06:00
|
|
|
return content, nil
|
2024-02-21 21:55:38 -07:00
|
|
|
case "error":
|
|
|
|
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
|
|
|
|
default:
|
|
|
|
fmt.Printf("\nUnrecognized event: %s\n", data)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
|
|
return "", fmt.Errorf("failed to read response body: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return "", fmt.Errorf("unexpected end of stream")
|
|
|
|
}
|