Private
Public Access
1
0
Files
lmcli/pkg/lmcli/provider/anthropic/anthropic.go
Matt Low 46149e0b67 Attempt to fix anthropic tool calling
Models have been way too eager to use tools when the task does not
require it (for example, reading the filesystem in order to show a
code example)
2024-03-17 22:55:02 +00:00

349 lines
9.0 KiB
Go

package anthropic
import (
"bufio"
"bytes"
"context"
"encoding/json"
"encoding/xml"
"fmt"
"net/http"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
)
type (
	// AnthropicClient is the Anthropic provider; APIKey is sent with every
	// request to the Messages API.
	AnthropicClient struct {
		APIKey string
	}

	// Message is one conversation turn in the Messages API wire format,
	// carrying a role and plain-text content.
	Message struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	// Request is the JSON body POSTed to the Messages API.
	//
	// NOTE(review): because Temperature is tagged omitempty, an explicit
	// temperature of 0 is dropped from the JSON and the API's default is
	// used instead — confirm this is intended.
	Request struct {
		Model         string    `json:"model"`
		Messages      []Message `json:"messages"`
		System        string    `json:"system,omitempty"`
		MaxTokens     int       `json:"max_tokens,omitempty"`
		StopSequences []string  `json:"stop_sequences,omitempty"`
		Stream        bool      `json:"stream,omitempty"`
		Temperature   float32   `json:"temperature,omitempty"`
		// Not sent yet; kept for reference.
		//TopP float32 `json:"top_p,omitempty"`
		//TopK float32 `json:"top_k,omitempty"`
	}

	// OriginalContent is a single content block of a non-streamed Response.
	// Only blocks whose Type is "text" are consumed by this package.
	OriginalContent struct {
		Type string `json:"type"`
		Text string `json:"text"`
	}

	// Response is the decoded body of a non-streamed Messages API reply.
	Response struct {
		Id      string            `json:"id"`
		Type    string            `json:"type"`
		Role    string            `json:"role"`
		Content []OriginalContent `json:"content"`
	}
)
// FUNCTION_STOP_SEQUENCE is the closing tag of an XML tool-call block. It is
// registered as an API stop sequence so generation halts as soon as the model
// finishes writing its function calls, letting the client execute them.
// (The ALL_CAPS name is non-idiomatic Go but is kept for existing references.)
const (
	FUNCTION_STOP_SEQUENCE = "</function_calls>"
)
// buildRequest converts provider-agnostic parameters and conversation history
// into a (non-streaming) Anthropic Messages API request body.
//
// A leading system message, if present, replaces params.SystemPrompt and is
// not sent as a message. When params.ToolBag is non-empty, the prompt from
// buildToolsSystemPrompt is appended to the system prompt, separated from any
// existing text by a "---" divider. Tool-call messages become assistant
// messages whose content is the original text (if any) followed by the
// <function_calls> XML; tool-result messages become user messages holding the
// serialized results. Both FUNCTION_STOP_SEQUENCE and "\n\nHuman:" are set as
// stop sequences.
//
// It panics if tool calls or tool results cannot be serialized to XML.
func buildRequest(params model.RequestParameters, messages []model.Message) Request {
	system := params.SystemPrompt
	history := messages
	if len(history) > 0 && history[0].Role == model.MessageRoleSystem {
		// The Messages API takes the system prompt as a top-level field.
		system = history[0].Content
		history = history[1:]
	}

	if len(params.ToolBag) > 0 {
		toolsPrompt := buildToolsSystemPrompt(params.ToolBag)
		if system == "" {
			system = toolsPrompt
		} else {
			// divider between the existing system prompt and the tool prompt
			system = system + "\n\n---\n\n" + toolsPrompt
		}
	}

	wire := make([]Message, len(history))
	for i, msg := range history {
		switch msg.Role {
		case model.MessageRoleToolCall:
			funcCalls := convertToolCallsToXMLFunctionCalls(msg.ToolCalls)
			callsXML, err := funcCalls.XMLString()
			if err != nil {
				panic("Could not serialize []ToolCall to XMLFunctionCall")
			}
			text := callsXML
			if msg.Content != "" {
				text = msg.Content + "\n\n" + callsXML
			}
			wire[i] = Message{Role: "assistant", Content: text}
		case model.MessageRoleToolResult:
			funcResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
			resultsXML, err := funcResults.XMLString()
			if err != nil {
				panic("Could not serialize []ToolResult to XMLFunctionResults")
			}
			wire[i] = Message{Role: "user", Content: resultsXML}
		default:
			wire[i] = Message{Role: string(msg.Role), Content: msg.Content}
		}
	}

	return Request{
		Model:       params.Model,
		Messages:    wire,
		System:      system,
		MaxTokens:   params.MaxTokens,
		Temperature: params.Temperature,
		Stream:      false,
		StopSequences: []string{
			FUNCTION_STOP_SEQUENCE,
			"\n\nHuman:",
		},
	}
}
// sendRequest POSTs r as JSON to the Anthropic Messages API, authenticating
// with c.APIKey. The request is bound to ctx, which governs the whole
// exchange, including reading the response body.
//
// On a 2xx status the caller owns the returned response and must close its
// Body. Any other status is turned into an error here (the body is read and
// closed), so callers never mistake an API error payload for a completion.
func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) {
	url := "https://api.anthropic.com/v1/messages"

	jsonBody, err := json.Marshal(r)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request body: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("failed to create HTTP request: %w", err)
	}
	req.Header.Set("x-api-key", c.APIKey)
	req.Header.Set("anthropic-version", "2023-06-01")
	req.Header.Set("content-type", "application/json")

	// DefaultClient is equivalent to a zero-value http.Client; reusing it
	// avoids building a new client per request. Timeouts come from ctx.
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send HTTP request: %w", err)
	}

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		defer resp.Body.Close()
		// Error bodies look like {"type":"error","error":{"type":...,"message":...}}.
		var payload struct {
			Error struct {
				Type    string `json:"type"`
				Message string `json:"message"`
			} `json:"error"`
		}
		if json.NewDecoder(resp.Body).Decode(&payload) == nil && payload.Error.Message != "" {
			return nil, fmt.Errorf("an error occurred (HTTP %d): %s: %s",
				resp.StatusCode, payload.Error.Type, payload.Error.Message)
		}
		return nil, fmt.Errorf("unexpected HTTP status: %s", resp.Status)
	}

	return resp, nil
}
// CreateChatCompletion sends the conversation to the Anthropic Messages API
// without streaming and returns the concatenated text of every content block
// in the reply. If callback is non-nil it is invoked once per content block
// with an assistant message holding that block's text.
//
// It returns an error if the request fails, if the API answers with an error
// payload or a non-2xx status, if the body cannot be decoded, or if the reply
// contains a content block whose type is not "text".
//
// NOTE(review): unlike CreateChatCompletionStream, this path does not detect
// the FUNCTION_STOP_SEQUENCE stop or execute tool calls; such a reply is
// returned as plain text.
func (c *AnthropicClient) CreateChatCompletion(
	ctx context.Context,
	params model.RequestParameters,
	messages []model.Message,
	callback provider.ReplyCallback,
) (string, error) {
	request := buildRequest(params, messages)
	resp, err := sendRequest(ctx, c, request)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	// Decode the normal reply fields plus the optional "error" member, so an
	// error payload is reported rather than read as an empty completion.
	var body struct {
		Response
		Error *struct {
			Type    string `json:"type"`
			Message string `json:"message"`
		} `json:"error"`
	}
	decodeErr := json.NewDecoder(resp.Body).Decode(&body)
	switch {
	case decodeErr == nil && body.Error != nil:
		return "", fmt.Errorf("an error occurred: %s: %s", body.Error.Type, body.Error.Message)
	case resp.StatusCode/100 != 2:
		return "", fmt.Errorf("unexpected HTTP status: %s", resp.Status)
	case decodeErr != nil:
		return "", fmt.Errorf("failed to decode response: %w", decodeErr)
	}

	var sb strings.Builder
	for _, block := range body.Content {
		if block.Type != "text" {
			return "", fmt.Errorf("unsupported message type: %s", block.Type)
		}
		sb.WriteString(block.Text)
		if callback != nil {
			callback(model.Message{
				Role:    model.MessageRoleAssistant,
				Content: block.Text,
			})
		}
	}
	return sb.String(), nil
}
// CreateChatCompletionStream sends the conversation to the Anthropic Messages
// API with streaming enabled and forwards each text delta to output as it
// arrives (an empty string is sent on "ping" to signal the start of text).
// It returns the full text of the final assistant reply.
//
// Tool use is emulated with XML. When the model stops on the
// FUNCTION_STOP_SEQUENCE stop sequence, the trailing <function_calls> block is
// parsed and its calls are executed against params.ToolBag. The tool-call and
// tool-result messages are passed to callback (if non-nil), and the request is
// re-issued recursively with those two messages appended. In that case the
// returned string is the text of the follow-up reply only; the prose that
// preceded the function calls is delivered as the tool-call message's Content.
//
// On a normal message_stop, callback (if non-nil) receives the final
// assistant message. Sends on output are abandoned with ctx.Err() if ctx is
// cancelled, so a consumer that stops reading cannot block this call forever.
func (c *AnthropicClient) CreateChatCompletionStream(
	ctx context.Context,
	params model.RequestParameters,
	messages []model.Message,
	callback provider.ReplyCallback,
	output chan<- string,
) (string, error) {
	request := buildRequest(params, messages)
	request.Stream = true

	resp, err := sendRequest(ctx, c, request)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	// emit forwards a chunk to output unless ctx is cancelled first.
	emit := func(chunk string) error {
		select {
		case output <- chunk:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}

	scanner := bufio.NewScanner(resp.Body)
	sb := strings.Builder{}

	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if len(line) == 0 {
			continue
		}

		if line[0] == '{' {
			// A bare JSON object instead of an SSE "data:" line: the API
			// answered with a single non-streamed payload.
			var event map[string]interface{}
			if err := json.Unmarshal([]byte(line), &event); err != nil {
				return "", fmt.Errorf("failed to unmarshal event data '%s': %w", line, err)
			}
			eventType, ok := event["type"].(string)
			if !ok {
				return "", fmt.Errorf("invalid event: %s", line)
			}
			if eventType == "error" {
				return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
			}
			return sb.String(), fmt.Errorf("unknown event type: %s", eventType)
		}

		if !strings.HasPrefix(line, "data:") {
			// Other SSE fields (e.g. "event:") carry nothing we need.
			continue
		}

		data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
		var event map[string]interface{}
		if err := json.Unmarshal([]byte(data), &event); err != nil {
			return "", fmt.Errorf("failed to unmarshal event data: %w", err)
		}
		eventType, ok := event["type"].(string)
		if !ok {
			return "", fmt.Errorf("invalid event type")
		}

		switch eventType {
		case "message_start", "content_block_start", "content_block_stop":
			// nothing to do
		case "ping":
			// write an empty string to signal start of text
			if err := emit(""); err != nil {
				return sb.String(), err
			}
		case "content_block_delta":
			delta, ok := event["delta"].(map[string]interface{})
			if !ok {
				return "", fmt.Errorf("invalid content block delta")
			}
			text, ok := delta["text"].(string)
			if !ok {
				return "", fmt.Errorf("invalid text delta")
			}
			sb.WriteString(text)
			if err := emit(text); err != nil {
				return sb.String(), err
			}
		case "message_delta":
			delta, ok := event["delta"].(map[string]interface{})
			if !ok {
				return "", fmt.Errorf("invalid message delta")
			}
			stopReason, _ := delta["stop_reason"].(string)
			stopSequence, _ := delta["stop_sequence"].(string)
			if stopReason != "stop_sequence" || stopSequence != FUNCTION_STOP_SEQUENCE {
				break // not a tool call; keep reading until message_stop
			}

			content := sb.String()
			// The stop sequence closes the most recently opened block, so
			// search backwards; earlier prose may merely mention the tag.
			start := strings.LastIndex(content, "<function_calls>")
			if start == -1 {
				return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
			}

			// The stop sequence itself is not part of the streamed text;
			// restore it for the reply and the display.
			sb.WriteString(FUNCTION_STOP_SEQUENCE)
			if err := emit(FUNCTION_STOP_SEQUENCE); err != nil {
				return sb.String(), err
			}

			// Parse only the function_calls block: the preceding prose may
			// contain '<' or '&' (e.g. code samples) that is not valid XML.
			funcCallXML := content[start:] + FUNCTION_STOP_SEQUENCE
			var functionCalls XMLFunctionCalls
			if err := xml.Unmarshal([]byte(funcCallXML), &functionCalls); err != nil {
				return "", fmt.Errorf("failed to unmarshal function_calls: %w", err)
			}

			// Execute function calls
			toolCall := model.Message{
				Role: model.MessageRoleToolCall,
				// xml stripped from content
				Content:   content[:start],
				ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
			}
			toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
			if err != nil {
				return "", err
			}
			toolReply := model.Message{
				Role:        model.MessageRoleToolResult,
				ToolResults: toolResults,
			}
			if callback != nil {
				callback(toolCall)
				callback(toolReply)
			}

			// Clip capacity so append copies rather than writing into spare
			// capacity of the caller's backing array.
			history := append(messages[:len(messages):len(messages)], toolCall, toolReply)
			return c.CreateChatCompletionStream(ctx, params, history, callback, output)
		case "message_stop":
			// return the completed message
			reply := sb.String()
			if callback != nil {
				callback(model.Message{
					Role:    model.MessageRoleAssistant,
					Content: reply,
				})
			}
			return reply, nil
		case "error":
			return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
		default:
			fmt.Printf("\nUnrecognized event: %s\n", data)
		}
	}

	if err := scanner.Err(); err != nil {
		return "", fmt.Errorf("failed to read response body: %w", err)
	}
	return "", fmt.Errorf("unexpected end of stream")
}