Support Anthropic's native tool calling API

This commit is contained in:
Matt Low 2024-06-23 01:46:27 +00:00
parent c50b6b154d
commit 94d84ba7d7
6 changed files with 357 additions and 468 deletions

View File

@ -3,7 +3,7 @@
- [x] Strip anthropic XML function call scheme from content, to reconstruct
when calling anthropic?
- [x] `dir_tree` tool
- [ ] Implement native Anthropic API tool calling
- [x] Implement native Anthropic API tool calling
- [ ] Agents - a name given to a system prompt + set of available tools +
potentially other relevant data (e.g. external service credentials, files for
RAG, etc), which the user explicitly selects (e.g. `lmcli chat --agent

View File

@ -40,3 +40,10 @@ type ChatCompletionProvider interface {
chunks chan<- Chunk,
) (*Message, error)
// IsAssistantContinuation reports whether the last entry of messages has the
// assistant role. An empty slice yields false.
func IsAssistantContinuation(messages []Message) bool {
// Guard first so the index expression below never runs on an empty slice.
if len(messages) == 0 {
return false
return messages[len(messages)-1].Role == MessageRoleAssistant

View File

@ -5,100 +5,223 @@ import (
func buildRequest(params api.RequestParameters, messages []api.Message) Request {
requestBody := Request{
Model: params.Model,
Messages: make([]Message, len(messages)),
MaxTokens: params.MaxTokens,
Temperature: params.Temperature,
Stream: false,
const ANTHROPIC_VERSION = "2023-06-01"
StopSequences: []string{
type AnthropicClient struct {
APIKey string
BaseURL string
startIdx := 0
if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem {
requestBody.System = messages[0].Content
requestBody.Messages = requestBody.Messages[1:]
startIdx = 1
// ChatCompletionMessage is one entry of the request's "messages" array.
//
// Content is an interface{} because createChatCompletionRequest stores either
// the plain message string or a slice of content-block maps (tool_use /
// tool_result) in it.
type ChatCompletionMessage struct {
Role string `json:"role"`
Content interface{} `json:"content"`
// Tool describes a single tool definition as serialized into the request's
// "tools" array; convertTools builds these values from api.ToolSpec.
type Tool struct {
Name string `json:"name"`
Description string `json:"description"`
InputSchema InputSchema `json:"input_schema"`
// InputSchema is the parameter schema of a Tool, serialized under
// "input_schema". convertTools always sets Type to "object".
//
// NOTE(review): Required has no omitempty, so a nil slice is encoded by
// encoding/json as "required": null rather than being omitted or [] —
// confirm the API accepts that.
type InputSchema struct {
Type string `json:"type"`
Properties map[string]Property `json:"properties"`
Required []string `json:"required"`
// Property describes one tool parameter inside InputSchema.Properties, keyed
// there by the parameter name. Enum is omitted from the JSON when empty.
type Property struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
// ChatCompletionRequest is the JSON body that sendRequest marshals and POSTs
// to the "/v1/messages" endpoint.
//
// System, Tools and Temperature are dropped from the JSON when they hold
// their zero value; Model, Messages, MaxTokens and Stream are always emitted.
// Stream is set by the caller after the request has been built.
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []ChatCompletionMessage `json:"messages"`
System string `json:"system,omitempty"`
Tools []Tool `json:"tools,omitempty"`
MaxTokens int `json:"max_tokens"`
Temperature float32 `json:"temperature,omitempty"`
Stream bool `json:"stream"`
// ContentBlock is one element of a response's "content" array. Text is used
// by "text" blocks; ID, Name and Input are used by "tool_use" blocks.
type ContentBlock struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input interface{} `json:"input,omitempty"`
// partialJsonAccumulator is unexported, so encoding/json never reads or
// writes it. The streaming code appends each input_json_delta fragment to it
// and parses the accumulated string into Input when the block stops.
partialJsonAccumulator string
// ChatCompletionResponse is the decoded message object returned by the API.
// The non-streaming path decodes the whole response body into it; the
// streaming path fills it from the message_start event and later updates
// StopReason and Content itself.
type ChatCompletionResponse struct {
ID string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Model string `json:"model"`
Content []ContentBlock `json:"content"`
StopReason string `json:"stop_reason"`
Usage Usage `json:"usage"`
// Usage carries the input and output token counts found under the "usage"
// key of a ChatCompletionResponse.
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
// StreamEvent is the generic envelope each "data: " line of the event stream
// is unmarshaled into. Type selects the handling; Message and Delta stay
// untyped (interface{}) and are type-asserted or re-decoded per event type.
//
// NOTE(review): there is no field for an "error" payload, and the "error"
// event handler formats Message instead — confirm against the Anthropic
// streaming documentation where error details are actually carried.
type StreamEvent struct {
Type string `json:"type"`
Message interface{} `json:"message,omitempty"`
Index int `json:"index,omitempty"`
Delta interface{} `json:"delta,omitempty"`
// convertTools translates the provider-independent tool specs into the Tool
// values sent in the request's "tools" array. The result has the same length
// and order as tools.
func convertTools(tools []api.ToolSpec) []Tool {
anthropicTools := make([]Tool, len(tools))
for i, tool := range tools {
// First pass over the parameters: build the schema's property map, keyed
// by parameter name.
properties := make(map[string]Property)
for _, param := range tool.Parameters {
properties[param.Name] = Property{
Type: param.Type,
Description: param.Description,
Enum: param.Enum,
// Second pass: collect the names of parameters marked as required.
// NOTE(review): required stays nil when no parameter is required, and
// InputSchema.Required has no omitempty, so it is encoded as null —
// confirm the API accepts that.
var required []string
for _, param := range tool.Parameters {
if param.Required {
required = append(required, param.Name)
anthropicTools[i] = Tool{
Name: tool.Name,
Description: tool.Description,
InputSchema: InputSchema{
// Tool input is always described as a JSON object.
Type: "object",
Properties: properties,
Required: required,
return anthropicTools
func createChatCompletionRequest(
params api.RequestParameters,
messages []api.Message,
) (string, ChatCompletionRequest) {
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
var systemMessage string
for _, m := range messages {
if m.Role == api.MessageRoleSystem {
systemMessage = m.Content
var content interface{}
role := string(m.Role)
switch m.Role {
case api.MessageRoleToolCall:
role = "assistant"
contentBlocks := make([]map[string]interface{}, 0)
if m.Content != "" {
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "text",
"text": m.Content,
for _, toolCall := range m.ToolCalls {
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "tool_use",
"id": toolCall.ID,
"name": toolCall.Name,
"input": toolCall.Parameters,
content = contentBlocks
case api.MessageRoleToolResult:
role = "user"
contentBlocks := make([]map[string]interface{}, 0)
for _, result := range m.ToolResults {
contentBlock := map[string]interface{}{
"type": "tool_result",
"tool_use_id": result.ToolCallID,
"content": result.Result,
contentBlocks = append(contentBlocks, contentBlock)
content = contentBlocks
content = m.Content
requestMessages = append(requestMessages, ChatCompletionMessage{
Role: role,
Content: content,
request := ChatCompletionRequest{
Model: params.Model,
Messages: requestMessages,
System: systemMessage,
MaxTokens: params.MaxTokens,
Temperature: params.Temperature,
if len(params.ToolBag) > 0 {
if len(requestBody.System) > 0 {
// add a divider between existing system prompt and tools
requestBody.System += "\n\n---\n\n"
requestBody.System += buildToolsSystemPrompt(params.ToolBag)
request.Tools = convertTools(params.ToolBag)
for i, msg := range messages[startIdx:] {
message := &requestBody.Messages[i]
switch msg.Role {
case api.MessageRoleToolCall:
message.Role = "assistant"
if msg.Content != "" {
message.Content = msg.Content
xmlFuncCalls := convertToolCallsToXMLFunctionCalls(msg.ToolCalls)
xmlString, err := xmlFuncCalls.XMLString()
if err != nil {
panic("Could not serialize []ToolCall to XMLFunctionCall")
if len(message.Content) > 0 {
message.Content += fmt.Sprintf("\n\n%s", xmlString)
} else {
message.Content = xmlString
case api.MessageRoleToolResult:
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
xmlString, err := xmlFuncResults.XMLString()
if err != nil {
panic("Could not serialize []ToolResult to XMLFunctionResults")
message.Role = "user"
message.Content = xmlString
message.Role = string(msg.Role)
message.Content = msg.Content
return requestBody
var prefill string
if api.IsAssistantContinuation(messages) {
prefill = messages[len(messages)-1].Content
func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) {
jsonBody, err := json.Marshal(r)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %v", err)
return prefill, request
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/messages", bytes.NewBuffer(jsonBody))
func (c *AnthropicClient) sendRequest(ctx context.Context, r ChatCompletionRequest) (*http.Response, error) {
jsonData, err := json.Marshal(r)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %v", err)
return nil, fmt.Errorf("failed to marshal request: %w", err)
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/v1/messages", bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
req.Header.Set("x-api-key", c.APIKey)
req.Header.Set("anthropic-version", "2023-06-01")
req.Header.Set("anthropic-version", ANTHROPIC_VERSION)
req.Header.Set("content-type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send HTTP request: %v", err)
return nil, err
return resp, nil
if resp.StatusCode != 200 {
bytes, _ := io.ReadAll(resp.Body)
return resp, fmt.Errorf("%v", string(bytes))
return resp, err
func (c *AnthropicClient) CreateChatCompletion(
@ -107,45 +230,25 @@ func (c *AnthropicClient) CreateChatCompletion(
messages []api.Message,
) (*api.Message, error) {
if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages")
return nil, fmt.Errorf("can't create completion from no messages")
request := buildRequest(params, messages)
_, req := createChatCompletionRequest(params, messages)
req.Stream = false
resp, err := sendRequest(ctx, c, request)
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to send request: %w", err)
defer resp.Body.Close()
var response Response
err = json.NewDecoder(resp.Body).Decode(&response)
var completionResp ChatCompletionResponse
err = json.NewDecoder(resp.Body).Decode(&completionResp)
if err != nil {
return nil, fmt.Errorf("failed to decode response: %v", err)
return nil, fmt.Errorf("failed to decode response: %w", err)
sb := strings.Builder{}
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
// this is a continuation of a previous assistant reply, so we'll
// include its contents in the final result
for _, content := range response.Content {
switch content.Type {
case "text":
return nil, fmt.Errorf("unsupported message type: %s", content.Type)
return &api.Message{
Role: api.MessageRoleAssistant,
Content: sb.String(),
}, nil
return convertResponseToMessage(completionResp)
func (c *AnthropicClient) CreateChatCompletionStream(
@ -155,144 +258,193 @@ func (c *AnthropicClient) CreateChatCompletionStream(
output chan<- api.Chunk,
) (*api.Message, error) {
if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages")
return nil, fmt.Errorf("can't create completion from no messages")
request := buildRequest(params, messages)
request.Stream = true
prefill, req := createChatCompletionRequest(params, messages)
req.Stream = true
resp, err := sendRequest(ctx, c, request)
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to send request: %w", err)
defer resp.Body.Close()
sb := strings.Builder{}
contentBlocks := make(map[int]*ContentBlock)
var finalMessage *ChatCompletionResponse
lastMessage := messages[len(messages)-1]
if messages[len(messages)-1].Role.IsAssistant() {
// this is a continuation of a previous assistant reply, so we'll
// include its contents in the final result
// TODO: handle this at higher level
var firstChunkReceived bool
reader := bufio.NewReader(resp.Body)
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
line, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
return nil, fmt.Errorf("error reading stream: %w", err)
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
line = strings.TrimSpace(line)
if len(line) == 0 {
line = bytes.TrimSpace(line)
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
if line[0] == '{' {
var event map[string]interface{}
err := json.Unmarshal([]byte(line), &event)
line = bytes.TrimPrefix(line, []byte("data: "))
var streamEvent StreamEvent
err = json.Unmarshal(line, &streamEvent)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal event data '%s': %v", line, err)
eventType, ok := event["type"].(string)
if !ok {
return nil, fmt.Errorf("invalid event: %s", line)
switch eventType {
case "error":
return nil, fmt.Errorf("an error occurred: %s", event["error"])
return nil, fmt.Errorf("unknown event type: %s", eventType)
} else if strings.HasPrefix(line, "data:") {
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
var event map[string]interface{}
err := json.Unmarshal([]byte(data), &event)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal event data: %v", err)
return nil, fmt.Errorf("failed to unmarshal stream event: %w", err)
eventType, ok := event["type"].(string)
if !ok {
return nil, fmt.Errorf("invalid event type")
switch eventType {
switch streamEvent.Type {
case "message_start":
// noop
case "ping":
// signals start of text - currently ignoring
finalMessage = &ChatCompletionResponse{}
err = json.Unmarshal(line, &struct {
Message *ChatCompletionResponse `json:"message"`
}{Message: finalMessage})
if err != nil {
return nil, fmt.Errorf("failed to unmarshal message_start: %w", err)
case "content_block_start":
// ignore?
var contentBlockStart struct {
Index int `json:"index"`
ContentBlock ContentBlock `json:"content_block"`
err = json.Unmarshal(line, &contentBlockStart)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal content_block_start: %w", err)
contentBlocks[contentBlockStart.Index] = &contentBlockStart.ContentBlock
case "content_block_delta":
delta, ok := event["delta"].(map[string]interface{})
if !ok {
return nil, fmt.Errorf("invalid content block delta")
if streamEvent.Index >= len(contentBlocks) {
return nil, fmt.Errorf("received delta for non-existent content block index: %d", streamEvent.Index)
text, ok := delta["text"].(string)
block := contentBlocks[streamEvent.Index]
delta, ok := streamEvent.Delta.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("invalid text delta")
return nil, fmt.Errorf("unexpected delta type: %T", streamEvent.Delta)
deltaType, ok := delta["type"].(string)
if !ok {
return nil, fmt.Errorf("delta missing type field")
switch deltaType {
case "text_delta":
if text, ok := delta["text"].(string); ok {
if !firstChunkReceived {
if prefill == "" {
// if there is no prefill, ensure we trim leading whitespace
text = strings.TrimSpace(text)
firstChunkReceived = true
block.Text += text
output <- api.Chunk{
Content: text,
TokenCount: 1,
case "input_json_delta":
if block.Type != "tool_use" {
return nil, fmt.Errorf("received input_json_delta for non-tool_use block")
if partialJSON, ok := delta["partial_json"].(string); ok {
block.partialJsonAccumulator += partialJSON
case "content_block_stop":
// ignore?
case "message_delta":
delta, ok := event["delta"].(map[string]interface{})
if !ok {
return nil, fmt.Errorf("invalid message delta")
stopReason, ok := delta["stop_reason"].(string)
if ok && stopReason == "stop_sequence" {
stopSequence, ok := delta["stop_sequence"].(string)
if ok && stopSequence == FUNCTION_STOP_SEQUENCE {
content := sb.String()
start := strings.Index(content, "<function_calls>")
if start == -1 {
return nil, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
if streamEvent.Index >= len(contentBlocks) {
return nil, fmt.Errorf("received stop for non-existent content block index: %d", streamEvent.Index)
output <- api.Chunk{
TokenCount: 1,
funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE
var functionCalls XMLFunctionCalls
err := xml.Unmarshal([]byte(funcCallXml), &functionCalls)
block := contentBlocks[streamEvent.Index]
if block.Type == "tool_use" && block.partialJsonAccumulator != "" {
var inputData map[string]interface{}
err := json.Unmarshal([]byte(block.partialJsonAccumulator), &inputData)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal function_calls: %v", err)
return nil, fmt.Errorf("failed to unmarshal accumulated JSON for tool use: %w", err)
block.Input = inputData
case "message_delta":
if finalMessage == nil {
return nil, fmt.Errorf("received message_delta before message_start")
delta, ok := streamEvent.Delta.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("unexpected delta type in message_delta: %T", streamEvent.Delta)
if stopReason, ok := delta["stop_reason"].(string); ok {
finalMessage.StopReason = stopReason
return &api.Message{
Role: api.MessageRoleToolCall,
// function call xml stripped from content for model interop
Content: strings.TrimSpace(content[:start]),
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
}, nil
case "message_stop":
// return the completed message
content := sb.String()
return &api.Message{
Role: api.MessageRoleAssistant,
Content: content,
}, nil
// End of the stream
case "error":
return nil, fmt.Errorf("an error occurred: %s", event["error"])
return nil, fmt.Errorf("received error event: %v", streamEvent.Message)
fmt.Printf("\nUnrecognized event: %s\n", data)
// Ignore unknown event types
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("failed to read response body: %v", err)
if finalMessage == nil {
return nil, fmt.Errorf("no final message received")
return nil, fmt.Errorf("unexpected end of stream")
// NOTE(review): make is called with a length (not just a capacity), so
// Content already holds len(contentBlocks) zero-value ContentBlocks; the
// appends below then add the real blocks after them, leaving Content twice
// as long as intended. Ranging over the contentBlocks map also visits
// entries in unspecified order, so the blocks are not guaranteed to follow
// their stream index. Presumably intended:
// make([]ContentBlock, 0, len(contentBlocks)) filled by ascending index —
// confirm and fix.
finalMessage.Content = make([]ContentBlock, len(contentBlocks))
for _, v := range contentBlocks {
finalMessage.Content = append(finalMessage.Content, *v)
return convertResponseToMessage(*finalMessage)
// convertResponseToMessage turns a decoded API response into an api.Message.
//
// Each "tool_use" block becomes an api.ToolCall. The returned message has the
// assistant role unless at least one tool call was collected, in which case
// its role is api.MessageRoleToolCall. An error is returned when a tool_use
// block's Input is not a map[string]interface{}.
func convertResponseToMessage(resp ChatCompletionResponse) (*api.Message, error) {
content := strings.Builder{}
var toolCalls []api.ToolCall
for _, block := range resp.Content {
switch block.Type {
// NOTE(review): no statement is visible under this case in this view, so
// nothing is ever written to content here — confirm against the full file
// that text blocks are appended to content.
case "text":
case "tool_use":
// Input is untyped JSON; tool calls need it as a key/value map.
parameters, ok := block.Input.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("unexpected type for tool call parameters: %T", block.Input)
toolCalls = append(toolCalls, api.ToolCall{
ID: block.ID,
Name: block.Name,
Parameters: parameters,
message := &api.Message{
Role: api.MessageRoleAssistant,
Content: content.String(),
ToolCalls: toolCalls,
// Any collected tool call switches the message to the tool-call role.
if len(toolCalls) > 0 {
message.Role = api.MessageRoleToolCall
return message, nil

View File

@ -1,232 +0,0 @@
package anthropic
import (
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
const TOOL_PREAMBLE = `You have access to the following tools when replying.
You may call them like this:
Here are the tools available:`
const TOOL_PREAMBLE_FOOTER = `Recognize the utility of these tools in a broad range of different applications, and the power they give you to solve a wide range of different problems. However, ensure that the tools are used judiciously and only when clearly relevant to the user's request. Specifically:
1. Only use a tool if the user has explicitly requested or provided information that warrants its use. Do not make assumptions about files or data existing without the user mentioning them.
2. If there is ambiguity about whether using a tool is appropriate, ask a clarifying question to the user before proceeding. Confirm your understanding of their request and intent.
3. Prioritize providing direct responses and explanations based on your own knowledge and understanding. Use tools to supplement and enhance your responses when clearly applicable, but not as a default action.`
type XMLTools struct {
XMLName struct{} `xml:"tools"`
ToolDescriptions []XMLToolDescription `xml:"tool_description"`
type XMLToolDescription struct {
ToolName string `xml:"tool_name"`
Description string `xml:"description"`
Parameters []XMLToolParameter `xml:"parameters>parameter"`
type XMLToolParameter struct {
Name string `xml:"name"`
Type string `xml:"type"`
Description string `xml:"description"`
type XMLFunctionCalls struct {
XMLName struct{} `xml:"function_calls"`
Invoke []XMLFunctionInvoke `xml:"invoke"`
type XMLFunctionInvoke struct {
ToolName string `xml:"tool_name"`
Parameters XMLFunctionInvokeParameters `xml:"parameters"`
type XMLFunctionInvokeParameters struct {
String string `xml:",innerxml"`
type XMLFunctionResults struct {
XMLName struct{} `xml:"function_results"`
Result []XMLFunctionResult `xml:"result"`
type XMLFunctionResult struct {
ToolName string `xml:"tool_name"`
Stdout string `xml:"stdout"`
// accepts raw XML from XMLFunctionInvokeParameters.String, returns map of
// parameters name to value
func parseFunctionParametersXML(params string) map[string]interface{} {
lines := strings.Split(params, "\n")
ret := make(map[string]interface{}, len(lines))
for _, line := range lines {
i := strings.Index(line, ">")
if i == -1 {
j := strings.Index(line, "</")
if j == -1 {
// chop from after opening < to first > to get parameter name,
// then chop after > to first </ to get parameter value
ret[line[1:i]] = line[i+1 : j]
return ret
func convertToolsToXMLTools(tools []api.ToolSpec) XMLTools {
converted := make([]XMLToolDescription, len(tools))
for i, tool := range tools {
converted[i].ToolName = tool.Name
converted[i].Description = tool.Description
params := make([]XMLToolParameter, len(tool.Parameters))
for j, param := range tool.Parameters {
params[j].Name = param.Name
params[j].Description = param.Description
params[j].Type = param.Type
converted[i].Parameters = params
return XMLTools{
ToolDescriptions: converted,
func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []api.ToolCall {
toolCalls := make([]api.ToolCall, len(functionCalls.Invoke))
for i, invoke := range functionCalls.Invoke {
toolCalls[i].Name = invoke.ToolName
toolCalls[i].Parameters = parseFunctionParametersXML(invoke.Parameters.String)
return toolCalls
func convertToolCallsToXMLFunctionCalls(toolCalls []api.ToolCall) XMLFunctionCalls {
converted := make([]XMLFunctionInvoke, len(toolCalls))
for i, toolCall := range toolCalls {
var params XMLFunctionInvokeParameters
var paramXML string
for key, value := range toolCall.Parameters {
paramXML += fmt.Sprintf("<%s>%v</%s>\n", key, value, key)
params.String = paramXML
converted[i] = XMLFunctionInvoke{
ToolName: toolCall.Name,
Parameters: params,
return XMLFunctionCalls{
Invoke: converted,
func convertToolResultsToXMLFunctionResult(toolResults []api.ToolResult) XMLFunctionResults {
converted := make([]XMLFunctionResult, len(toolResults))
for i, result := range toolResults {
converted[i].ToolName = result.ToolName
converted[i].Stdout = result.Result
return XMLFunctionResults{
Result: converted,
func buildToolsSystemPrompt(tools []api.ToolSpec) string {
xmlTools := convertToolsToXMLTools(tools)
xmlToolsString, err := xmlTools.XMLString()
if err != nil {
panic("Could not serialize []api.Tool to XMLTools")
return TOOL_PREAMBLE + "\n\n" + xmlToolsString + "\n\n" + TOOL_PREAMBLE_FOOTER
func (x XMLTools) XMLString() (string, error) {
tmpl, err := template.New("tools").Parse(`<tools>
{{range .ToolDescriptions}}<tool_description>
{{range .Parameters}}<parameter>
if err != nil {
return "", err
var buf bytes.Buffer
if err := tmpl.Execute(&buf, x); err != nil {
return "", err
return buf.String(), nil
func (x XMLFunctionResults) XMLString() (string, error) {
tmpl, err := template.New("function_results").Parse(`<function_results>
{{range .Result}}<result>
if err != nil {
return "", err
var buf bytes.Buffer
if err := tmpl.Execute(&buf, x); err != nil {
return "", err
return buf.String(), nil
func (x XMLFunctionCalls) XMLString() (string, error) {
tmpl, err := template.New("function_calls").Parse(`<function_calls>
{{range .Invoke}}<invoke>
if err != nil {
return "", err
var buf bytes.Buffer
if err := tmpl.Execute(&buf, x); err != nil {
return "", err
return buf.String(), nil

View File

@ -1,38 +0,0 @@
package anthropic
type AnthropicClient struct {
BaseURL string
APIKey string
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
type Request struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
//TopP float32 `json:"top_p,omitempty"`
//TopK float32 `json:"top_k,omitempty"`
type OriginalContent struct {
Type string `json:"type"`
Text string `json:"text"`
type Response struct {
Id string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Content []OriginalContent `json:"content"`
StopReason string `json:"stop_reason"`
StopSequence string `json:"stop_sequence"`

View File

@ -96,7 +96,7 @@ func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProv
if m == model {
switch *p.Kind {
case "anthropic":
url := ""
url := ""
if p.BaseURL != nil {
url = *p.BaseURL