Models have been way too eager to use tools when the task does not require it (for example, reading the filesystem in order to show a code example)
349 lines
9.0 KiB
Go
349 lines
9.0 KiB
Go
package anthropic
|
|
|
|
import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"encoding/xml"
	"fmt"
	"io"
	"net/http"
	"strings"

	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
)
|
|
|
|
// AnthropicClient is a client for the Anthropic Messages API.
type AnthropicClient struct {
	// APIKey is sent as the x-api-key header on every request.
	APIKey string
}
|
|
|
|
// Message is a single chat message in the request payload, in the shape
// the Anthropic Messages API expects.
type Message struct {
	Role    string `json:"role"` // "user" or "assistant"
	Content string `json:"content"`
}
|
|
|
|
// Request is the JSON request body for POST /v1/messages.
type Request struct {
	Model string `json:"model"`
	Messages []Message `json:"messages"`
	// System carries the system prompt; it is kept out of Messages per the
	// Anthropic API contract.
	System string `json:"system,omitempty"`
	MaxTokens int `json:"max_tokens,omitempty"`
	StopSequences []string `json:"stop_sequences,omitempty"`
	// Stream requests server-sent-event streaming when true.
	Stream bool `json:"stream,omitempty"`
	Temperature float32 `json:"temperature,omitempty"`
	//TopP float32 `json:"top_p,omitempty"`
	//TopK float32 `json:"top_k,omitempty"`
}
|
|
|
|
// OriginalContent is one content block of an Anthropic API response.
// Only the "text" type is handled by this client.
type OriginalContent struct {
	Type string `json:"type"`
	Text string `json:"text"`
}
|
|
|
|
// Response is the JSON body of a non-streaming Messages API response.
type Response struct {
	Id string `json:"id"`
	Type string `json:"type"`
	Role string `json:"role"`
	Content []OriginalContent `json:"content"`
}
|
|
|
|
// FUNCTION_STOP_SEQUENCE stops generation once the model finishes emitting an
// XML <function_calls> block, so the calls can be parsed and executed.
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
|
|
|
|
func buildRequest(params model.RequestParameters, messages []model.Message) Request {
|
|
requestBody := Request{
|
|
Model: params.Model,
|
|
Messages: make([]Message, len(messages)),
|
|
System: params.SystemPrompt,
|
|
MaxTokens: params.MaxTokens,
|
|
Temperature: params.Temperature,
|
|
Stream: false,
|
|
|
|
StopSequences: []string{
|
|
FUNCTION_STOP_SEQUENCE,
|
|
"\n\nHuman:",
|
|
},
|
|
}
|
|
|
|
startIdx := 0
|
|
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
|
|
requestBody.System = messages[0].Content
|
|
requestBody.Messages = requestBody.Messages[1:]
|
|
startIdx = 1
|
|
}
|
|
|
|
if len(params.ToolBag) > 0 {
|
|
if len(requestBody.System) > 0 {
|
|
// add a divider between existing system prompt and tools
|
|
requestBody.System += "\n\n---\n\n"
|
|
}
|
|
requestBody.System += buildToolsSystemPrompt(params.ToolBag)
|
|
}
|
|
|
|
for i, msg := range messages[startIdx:] {
|
|
message := &requestBody.Messages[i]
|
|
|
|
switch msg.Role {
|
|
case model.MessageRoleToolCall:
|
|
message.Role = "assistant"
|
|
if msg.Content != "" {
|
|
message.Content = msg.Content
|
|
}
|
|
xmlFuncCalls := convertToolCallsToXMLFunctionCalls(msg.ToolCalls)
|
|
xmlString, err := xmlFuncCalls.XMLString()
|
|
if err != nil {
|
|
panic("Could not serialize []ToolCall to XMLFunctionCall")
|
|
}
|
|
if len(message.Content) > 0 {
|
|
message.Content += fmt.Sprintf("\n\n%s", xmlString)
|
|
} else {
|
|
message.Content = xmlString
|
|
}
|
|
case model.MessageRoleToolResult:
|
|
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
|
|
xmlString, err := xmlFuncResults.XMLString()
|
|
if err != nil {
|
|
panic("Could not serialize []ToolResult to XMLFunctionResults")
|
|
}
|
|
message.Role = "user"
|
|
message.Content = xmlString
|
|
default:
|
|
message.Role = string(msg.Role)
|
|
message.Content = msg.Content
|
|
}
|
|
}
|
|
return requestBody
|
|
}
|
|
|
|
func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) {
|
|
url := "https://api.anthropic.com/v1/messages"
|
|
|
|
jsonBody, err := json.Marshal(r)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request body: %v", err)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonBody))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create HTTP request: %v", err)
|
|
}
|
|
|
|
req.Header.Set("x-api-key", c.APIKey)
|
|
req.Header.Set("anthropic-version", "2023-06-01")
|
|
req.Header.Set("content-type", "application/json")
|
|
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to send HTTP request: %v", err)
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
func (c *AnthropicClient) CreateChatCompletion(
|
|
ctx context.Context,
|
|
params model.RequestParameters,
|
|
messages []model.Message,
|
|
callback provider.ReplyCallback,
|
|
) (string, error) {
|
|
request := buildRequest(params, messages)
|
|
|
|
resp, err := sendRequest(ctx, c, request)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
var response Response
|
|
err = json.NewDecoder(resp.Body).Decode(&response)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to decode response: %v", err)
|
|
}
|
|
|
|
sb := strings.Builder{}
|
|
for _, content := range response.Content {
|
|
var reply model.Message
|
|
switch content.Type {
|
|
case "text":
|
|
reply = model.Message{
|
|
Role: model.MessageRoleAssistant,
|
|
Content: content.Text,
|
|
}
|
|
sb.WriteString(reply.Content)
|
|
default:
|
|
return "", fmt.Errorf("unsupported message type: %s", content.Type)
|
|
}
|
|
if callback != nil {
|
|
callback(reply)
|
|
}
|
|
}
|
|
|
|
return sb.String(), nil
|
|
}
|
|
|
|
// CreateChatCompletionStream sends a streaming (SSE) completion request and
// writes text deltas to output as they arrive. When generation stops at
// FUNCTION_STOP_SEQUENCE, the accumulated <function_calls> XML is parsed,
// the requested tools are executed via params.ToolBag, and the function
// recurses with the tool call and tool result appended to the conversation.
// Returns the complete assistant reply text.
func (c *AnthropicClient) CreateChatCompletionStream(
	ctx context.Context,
	params model.RequestParameters,
	messages []model.Message,
	callback provider.ReplyCallback,
	output chan<- string,
) (string, error) {
	request := buildRequest(params, messages)
	request.Stream = true

	resp, err := sendRequest(ctx, c, request)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	// NOTE(review): bufio.Scanner has a 64KiB default line limit; an SSE data
	// line longer than that makes Scan stop with an error — confirm event
	// payloads stay under the limit or add scanner.Buffer.
	scanner := bufio.NewScanner(resp.Body)
	sb := strings.Builder{}

	// Set once a function call has been handled, so the message_stop branch
	// does not also emit a plain assistant message via the callback.
	isToolCall := false

	for scanner.Scan() {
		line := scanner.Text()
		line = strings.TrimSpace(line)

		if len(line) == 0 {
			continue
		}

		if line[0] == '{' {
			// A bare JSON object (no "data:" prefix) is only expected for
			// top-level error payloads; anything else is treated as fatal.
			var event map[string]interface{}
			err := json.Unmarshal([]byte(line), &event)
			if err != nil {
				return "", fmt.Errorf("failed to unmarshal event data '%s': %v", line, err)
			}
			eventType, ok := event["type"].(string)
			if !ok {
				return "", fmt.Errorf("invalid event: %s", line)
			}
			switch eventType {
			case "error":
				return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
			default:
				return sb.String(), fmt.Errorf("unknown event type: %s", eventType)
			}
		} else if strings.HasPrefix(line, "data:") {
			data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
			var event map[string]interface{}
			err := json.Unmarshal([]byte(data), &event)
			if err != nil {
				return "", fmt.Errorf("failed to unmarshal event data: %v", err)
			}

			eventType, ok := event["type"].(string)
			if !ok {
				return "", fmt.Errorf("invalid event type")
			}

			switch eventType {
			case "message_start":
				// noop
			case "ping":
				// write an empty string to signal start of text
				output <- ""
			case "content_block_start":
				// ignore?
			case "content_block_delta":
				// Incremental text: accumulate it and forward it downstream.
				delta, ok := event["delta"].(map[string]interface{})
				if !ok {
					return "", fmt.Errorf("invalid content block delta")
				}
				text, ok := delta["text"].(string)
				if !ok {
					return "", fmt.Errorf("invalid text delta")
				}
				sb.WriteString(text)
				output <- text
			case "content_block_stop":
				// ignore?
			case "message_delta":
				// Carries the stop reason; a stop at FUNCTION_STOP_SEQUENCE
				// means the model emitted a tool call to execute.
				delta, ok := event["delta"].(map[string]interface{})
				if !ok {
					return "", fmt.Errorf("invalid message delta")
				}
				stopReason, ok := delta["stop_reason"].(string)
				if ok && stopReason == "stop_sequence" {
					stopSequence, ok := delta["stop_sequence"].(string)
					if ok && stopSequence == FUNCTION_STOP_SEQUENCE {
						content := sb.String()

						start := strings.Index(content, "<function_calls>")
						if start == -1 {
							return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
						}

						isToolCall = true

						// The stop sequence itself is excluded from the
						// streamed text, so re-append the closing tag to keep
						// the XML well-formed.
						funcCallXml := content[start:]
						funcCallXml += FUNCTION_STOP_SEQUENCE

						sb.WriteString(FUNCTION_STOP_SEQUENCE)
						output <- FUNCTION_STOP_SEQUENCE

						// Extract function calls
						// NOTE(review): this parses the full accumulated text
						// (sb.String()), not the trimmed funcCallXml built
						// above — presumably xml.Unmarshal tolerates the
						// leading non-XML prose; confirm.
						var functionCalls XMLFunctionCalls
						err := xml.Unmarshal([]byte(sb.String()), &functionCalls)
						if err != nil {
							return "", fmt.Errorf("failed to unmarshal function_calls: %v", err)
						}

						// Execute function calls
						toolCall := model.Message{
							Role: model.MessageRoleToolCall,
							// xml stripped from content
							Content:   content[:start],
							ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
						}

						toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
						if err != nil {
							return "", err
						}

						toolReply := model.Message{
							Role:        model.MessageRoleToolResult,
							ToolResults: toolResults,
						}

						if callback != nil {
							callback(toolCall)
							callback(toolReply)
						}

						// Recurse into CreateChatCompletionStream with the tool call replies
						// added to the original messages
						messages = append(append(messages, toolCall), toolReply)
						return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
					}
				}
			case "message_stop":
				// return the completed message
				if callback != nil {
					if !isToolCall {
						callback(model.Message{
							Role:    model.MessageRoleAssistant,
							Content: sb.String(),
						})
					}
				}
				return sb.String(), nil
			case "error":
				return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
			default:
				fmt.Printf("\nUnrecognized event: %s\n", data)
			}
		}
	}

	if err := scanner.Err(); err != nil {
		return "", fmt.Errorf("failed to read response body: %v", err)
	}

	return "", fmt.Errorf("unexpected end of stream")
}
|