Support Anthropic's native tool calling API
This commit is contained in:
parent
c50b6b154d
commit
94d84ba7d7
2
TODO.md
2
TODO.md
@ -3,7 +3,7 @@
|
||||
- [x] Strip anthropic XML function call scheme from content, to reconstruct
|
||||
when calling anthropic?
|
||||
- [x] `dir_tree` tool
|
||||
- [ ] Implement native Anthropic API tool calling
|
||||
- [x] Implement native Anthropic API tool calling
|
||||
- [ ] Agents - a name given to a system prompt + set of available tools +
|
||||
potentially other relevent data (e.g. external service credentials, files for
|
||||
RAG, etc), which the user explicitly selects (e.g. `lmcli chat --agent
|
||||
|
@ -40,3 +40,10 @@ type ChatCompletionProvider interface {
|
||||
chunks chan<- Chunk,
|
||||
) (*Message, error)
|
||||
}
|
||||
|
||||
func IsAssistantContinuation(messages []Message) bool {
|
||||
if len(messages) == 0 {
|
||||
return false
|
||||
}
|
||||
return messages[len(messages)-1].Role == MessageRoleAssistant
|
||||
}
|
||||
|
@ -5,100 +5,223 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
)
|
||||
|
||||
func buildRequest(params api.RequestParameters, messages []api.Message) Request {
|
||||
requestBody := Request{
|
||||
Model: params.Model,
|
||||
Messages: make([]Message, len(messages)),
|
||||
MaxTokens: params.MaxTokens,
|
||||
Temperature: params.Temperature,
|
||||
Stream: false,
|
||||
const ANTHROPIC_VERSION = "2023-06-01"
|
||||
|
||||
StopSequences: []string{
|
||||
FUNCTION_STOP_SEQUENCE,
|
||||
"\n\nHuman:",
|
||||
},
|
||||
type AnthropicClient struct {
|
||||
APIKey string
|
||||
BaseURL string
|
||||
}
|
||||
|
||||
startIdx := 0
|
||||
if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem {
|
||||
requestBody.System = messages[0].Content
|
||||
requestBody.Messages = requestBody.Messages[1:]
|
||||
startIdx = 1
|
||||
type ChatCompletionMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content interface{} `json:"content"`
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
InputSchema InputSchema `json:"input_schema"`
|
||||
}
|
||||
|
||||
type InputSchema struct {
|
||||
Type string `json:"type"`
|
||||
Properties map[string]Property `json:"properties"`
|
||||
Required []string `json:"required"`
|
||||
}
|
||||
|
||||
type Property struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}
|
||||
|
||||
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ChatCompletionMessage `json:"messages"`
|
||||
System string `json:"system,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type ContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input interface{} `json:"input,omitempty"`
|
||||
partialJsonAccumulator string
|
||||
}
|
||||
|
||||
type ChatCompletionResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role"`
|
||||
Model string `json:"model"`
|
||||
Content []ContentBlock `json:"content"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
type StreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
Message interface{} `json:"message,omitempty"`
|
||||
Index int `json:"index,omitempty"`
|
||||
Delta interface{} `json:"delta,omitempty"`
|
||||
}
|
||||
|
||||
func convertTools(tools []api.ToolSpec) []Tool {
|
||||
anthropicTools := make([]Tool, len(tools))
|
||||
for i, tool := range tools {
|
||||
properties := make(map[string]Property)
|
||||
for _, param := range tool.Parameters {
|
||||
properties[param.Name] = Property{
|
||||
Type: param.Type,
|
||||
Description: param.Description,
|
||||
Enum: param.Enum,
|
||||
}
|
||||
}
|
||||
|
||||
var required []string
|
||||
for _, param := range tool.Parameters {
|
||||
if param.Required {
|
||||
required = append(required, param.Name)
|
||||
}
|
||||
}
|
||||
|
||||
anthropicTools[i] = Tool{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: properties,
|
||||
Required: required,
|
||||
},
|
||||
}
|
||||
}
|
||||
return anthropicTools
|
||||
}
|
||||
|
||||
func createChatCompletionRequest(
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (string, ChatCompletionRequest) {
|
||||
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
|
||||
var systemMessage string
|
||||
|
||||
for _, m := range messages {
|
||||
if m.Role == api.MessageRoleSystem {
|
||||
systemMessage = m.Content
|
||||
continue
|
||||
}
|
||||
|
||||
var content interface{}
|
||||
role := string(m.Role)
|
||||
|
||||
switch m.Role {
|
||||
case api.MessageRoleToolCall:
|
||||
role = "assistant"
|
||||
contentBlocks := make([]map[string]interface{}, 0)
|
||||
if m.Content != "" {
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": m.Content,
|
||||
})
|
||||
}
|
||||
for _, toolCall := range m.ToolCalls {
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": toolCall.ID,
|
||||
"name": toolCall.Name,
|
||||
"input": toolCall.Parameters,
|
||||
})
|
||||
}
|
||||
content = contentBlocks
|
||||
|
||||
case api.MessageRoleToolResult:
|
||||
role = "user"
|
||||
contentBlocks := make([]map[string]interface{}, 0)
|
||||
for _, result := range m.ToolResults {
|
||||
contentBlock := map[string]interface{}{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": result.ToolCallID,
|
||||
"content": result.Result,
|
||||
}
|
||||
contentBlocks = append(contentBlocks, contentBlock)
|
||||
}
|
||||
content = contentBlocks
|
||||
|
||||
default:
|
||||
content = m.Content
|
||||
}
|
||||
|
||||
requestMessages = append(requestMessages, ChatCompletionMessage{
|
||||
Role: role,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
request := ChatCompletionRequest{
|
||||
Model: params.Model,
|
||||
Messages: requestMessages,
|
||||
System: systemMessage,
|
||||
MaxTokens: params.MaxTokens,
|
||||
Temperature: params.Temperature,
|
||||
}
|
||||
|
||||
if len(params.ToolBag) > 0 {
|
||||
if len(requestBody.System) > 0 {
|
||||
// add a divider between existing system prompt and tools
|
||||
requestBody.System += "\n\n---\n\n"
|
||||
}
|
||||
requestBody.System += buildToolsSystemPrompt(params.ToolBag)
|
||||
request.Tools = convertTools(params.ToolBag)
|
||||
}
|
||||
|
||||
for i, msg := range messages[startIdx:] {
|
||||
message := &requestBody.Messages[i]
|
||||
|
||||
switch msg.Role {
|
||||
case api.MessageRoleToolCall:
|
||||
message.Role = "assistant"
|
||||
if msg.Content != "" {
|
||||
message.Content = msg.Content
|
||||
}
|
||||
xmlFuncCalls := convertToolCallsToXMLFunctionCalls(msg.ToolCalls)
|
||||
xmlString, err := xmlFuncCalls.XMLString()
|
||||
if err != nil {
|
||||
panic("Could not serialize []ToolCall to XMLFunctionCall")
|
||||
}
|
||||
if len(message.Content) > 0 {
|
||||
message.Content += fmt.Sprintf("\n\n%s", xmlString)
|
||||
} else {
|
||||
message.Content = xmlString
|
||||
}
|
||||
case api.MessageRoleToolResult:
|
||||
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
|
||||
xmlString, err := xmlFuncResults.XMLString()
|
||||
if err != nil {
|
||||
panic("Could not serialize []ToolResult to XMLFunctionResults")
|
||||
}
|
||||
message.Role = "user"
|
||||
message.Content = xmlString
|
||||
default:
|
||||
message.Role = string(msg.Role)
|
||||
message.Content = msg.Content
|
||||
}
|
||||
}
|
||||
return requestBody
|
||||
var prefill string
|
||||
if api.IsAssistantContinuation(messages) {
|
||||
prefill = messages[len(messages)-1].Content
|
||||
}
|
||||
|
||||
func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) {
|
||||
jsonBody, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %v", err)
|
||||
return prefill, request
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/messages", bytes.NewBuffer(jsonBody))
|
||||
func (c *AnthropicClient) sendRequest(ctx context.Context, r ChatCompletionRequest) (*http.Response, error) {
|
||||
jsonData, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %v", err)
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/v1/messages", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("x-api-key", c.APIKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
req.Header.Set("anthropic-version", ANTHROPIC_VERSION)
|
||||
req.Header.Set("content-type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send HTTP request: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
if resp.StatusCode != 200 {
|
||||
bytes, _ := io.ReadAll(resp.Body)
|
||||
return resp, fmt.Errorf("%v", string(bytes))
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (c *AnthropicClient) CreateChatCompletion(
|
||||
@ -107,45 +230,25 @@ func (c *AnthropicClient) CreateChatCompletion(
|
||||
messages []api.Message,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
return nil, fmt.Errorf("can't create completion from no messages")
|
||||
}
|
||||
|
||||
request := buildRequest(params, messages)
|
||||
_, req := createChatCompletionRequest(params, messages)
|
||||
req.Stream = false
|
||||
|
||||
resp, err := sendRequest(ctx, c, request)
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var response Response
|
||||
err = json.NewDecoder(resp.Body).Decode(&response)
|
||||
var completionResp ChatCompletionResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %v", err)
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
sb := strings.Builder{}
|
||||
|
||||
lastMessage := messages[len(messages)-1]
|
||||
if lastMessage.Role.IsAssistant() {
|
||||
// this is a continuation of a previous assistant reply, so we'll
|
||||
// include its contents in the final result
|
||||
sb.WriteString(lastMessage.Content)
|
||||
}
|
||||
|
||||
for _, content := range response.Content {
|
||||
switch content.Type {
|
||||
case "text":
|
||||
sb.WriteString(content.Text)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported message type: %s", content.Type)
|
||||
}
|
||||
}
|
||||
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: sb.String(),
|
||||
}, nil
|
||||
return convertResponseToMessage(completionResp)
|
||||
}
|
||||
|
||||
func (c *AnthropicClient) CreateChatCompletionStream(
|
||||
@ -155,144 +258,193 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
||||
output chan<- api.Chunk,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
return nil, fmt.Errorf("can't create completion from no messages")
|
||||
}
|
||||
|
||||
request := buildRequest(params, messages)
|
||||
request.Stream = true
|
||||
prefill, req := createChatCompletionRequest(params, messages)
|
||||
req.Stream = true
|
||||
|
||||
resp, err := sendRequest(ctx, c, request)
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
sb := strings.Builder{}
|
||||
contentBlocks := make(map[int]*ContentBlock)
|
||||
var finalMessage *ChatCompletionResponse
|
||||
|
||||
lastMessage := messages[len(messages)-1]
|
||||
if messages[len(messages)-1].Role.IsAssistant() {
|
||||
// this is a continuation of a previous assistant reply, so we'll
|
||||
// include its contents in the final result
|
||||
// TODO: handle this at higher level
|
||||
sb.WriteString(lastMessage.Content)
|
||||
var firstChunkReceived bool
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return nil, fmt.Errorf("error reading stream: %w", err)
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
if len(line) == 0 {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
|
||||
continue
|
||||
}
|
||||
|
||||
if line[0] == '{' {
|
||||
var event map[string]interface{}
|
||||
err := json.Unmarshal([]byte(line), &event)
|
||||
line = bytes.TrimPrefix(line, []byte("data: "))
|
||||
|
||||
var streamEvent StreamEvent
|
||||
err = json.Unmarshal(line, &streamEvent)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal event data '%s': %v", line, err)
|
||||
}
|
||||
eventType, ok := event["type"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid event: %s", line)
|
||||
}
|
||||
switch eventType {
|
||||
case "error":
|
||||
return nil, fmt.Errorf("an error occurred: %s", event["error"])
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown event type: %s", eventType)
|
||||
}
|
||||
} else if strings.HasPrefix(line, "data:") {
|
||||
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
var event map[string]interface{}
|
||||
err := json.Unmarshal([]byte(data), &event)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal event data: %v", err)
|
||||
return nil, fmt.Errorf("failed to unmarshal stream event: %w", err)
|
||||
}
|
||||
|
||||
eventType, ok := event["type"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid event type")
|
||||
}
|
||||
|
||||
switch eventType {
|
||||
switch streamEvent.Type {
|
||||
case "message_start":
|
||||
// noop
|
||||
case "ping":
|
||||
// signals start of text - currently ignoring
|
||||
finalMessage = &ChatCompletionResponse{}
|
||||
err = json.Unmarshal(line, &struct {
|
||||
Message *ChatCompletionResponse `json:"message"`
|
||||
}{Message: finalMessage})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal message_start: %w", err)
|
||||
}
|
||||
case "content_block_start":
|
||||
// ignore?
|
||||
var contentBlockStart struct {
|
||||
Index int `json:"index"`
|
||||
ContentBlock ContentBlock `json:"content_block"`
|
||||
}
|
||||
err = json.Unmarshal(line, &contentBlockStart)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal content_block_start: %w", err)
|
||||
}
|
||||
|
||||
contentBlocks[contentBlockStart.Index] = &contentBlockStart.ContentBlock
|
||||
case "content_block_delta":
|
||||
delta, ok := event["delta"].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid content block delta")
|
||||
if streamEvent.Index >= len(contentBlocks) {
|
||||
return nil, fmt.Errorf("received delta for non-existent content block index: %d", streamEvent.Index)
|
||||
}
|
||||
text, ok := delta["text"].(string)
|
||||
|
||||
block := contentBlocks[streamEvent.Index]
|
||||
delta, ok := streamEvent.Delta.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid text delta")
|
||||
return nil, fmt.Errorf("unexpected delta type: %T", streamEvent.Delta)
|
||||
}
|
||||
sb.WriteString(text)
|
||||
|
||||
deltaType, ok := delta["type"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("delta missing type field")
|
||||
}
|
||||
|
||||
switch deltaType {
|
||||
case "text_delta":
|
||||
if text, ok := delta["text"].(string); ok {
|
||||
if !firstChunkReceived {
|
||||
if prefill == "" {
|
||||
// if there is no prefil, ensure we trim leading whitespace
|
||||
text = strings.TrimSpace(text)
|
||||
}
|
||||
firstChunkReceived = true
|
||||
}
|
||||
block.Text += text
|
||||
output <- api.Chunk{
|
||||
Content: text,
|
||||
TokenCount: 1,
|
||||
}
|
||||
}
|
||||
case "input_json_delta":
|
||||
if block.Type != "tool_use" {
|
||||
return nil, fmt.Errorf("received input_json_delta for non-tool_use block")
|
||||
}
|
||||
if partialJSON, ok := delta["partial_json"].(string); ok {
|
||||
block.partialJsonAccumulator += partialJSON
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_stop":
|
||||
// ignore?
|
||||
case "message_delta":
|
||||
delta, ok := event["delta"].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid message delta")
|
||||
}
|
||||
stopReason, ok := delta["stop_reason"].(string)
|
||||
if ok && stopReason == "stop_sequence" {
|
||||
stopSequence, ok := delta["stop_sequence"].(string)
|
||||
if ok && stopSequence == FUNCTION_STOP_SEQUENCE {
|
||||
content := sb.String()
|
||||
|
||||
start := strings.Index(content, "<function_calls>")
|
||||
if start == -1 {
|
||||
return nil, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
|
||||
if streamEvent.Index >= len(contentBlocks) {
|
||||
return nil, fmt.Errorf("received stop for non-existent content block index: %d", streamEvent.Index)
|
||||
}
|
||||
|
||||
sb.WriteString(FUNCTION_STOP_SEQUENCE)
|
||||
output <- api.Chunk{
|
||||
Content: FUNCTION_STOP_SEQUENCE,
|
||||
TokenCount: 1,
|
||||
}
|
||||
funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE
|
||||
|
||||
var functionCalls XMLFunctionCalls
|
||||
err := xml.Unmarshal([]byte(funcCallXml), &functionCalls)
|
||||
block := contentBlocks[streamEvent.Index]
|
||||
if block.Type == "tool_use" && block.partialJsonAccumulator != "" {
|
||||
var inputData map[string]interface{}
|
||||
err := json.Unmarshal([]byte(block.partialJsonAccumulator), &inputData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal function_calls: %v", err)
|
||||
return nil, fmt.Errorf("failed to unmarshal accumulated JSON for tool use: %w", err)
|
||||
}
|
||||
block.Input = inputData
|
||||
}
|
||||
case "message_delta":
|
||||
if finalMessage == nil {
|
||||
return nil, fmt.Errorf("received message_delta before message_start")
|
||||
}
|
||||
delta, ok := streamEvent.Delta.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected delta type in message_delta: %T", streamEvent.Delta)
|
||||
}
|
||||
if stopReason, ok := delta["stop_reason"].(string); ok {
|
||||
finalMessage.StopReason = stopReason
|
||||
}
|
||||
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleToolCall,
|
||||
// function call xml stripped from content for model interop
|
||||
Content: strings.TrimSpace(content[:start]),
|
||||
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
case "message_stop":
|
||||
// return the completed message
|
||||
content := sb.String()
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: content,
|
||||
}, nil
|
||||
// End of the stream
|
||||
goto END_STREAM
|
||||
|
||||
case "error":
|
||||
return nil, fmt.Errorf("an error occurred: %s", event["error"])
|
||||
return nil, fmt.Errorf("received error event: %v", streamEvent.Message)
|
||||
|
||||
default:
|
||||
fmt.Printf("\nUnrecognized event: %s\n", data)
|
||||
// Ignore unknown event types
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %v", err)
|
||||
END_STREAM:
|
||||
if finalMessage == nil {
|
||||
return nil, fmt.Errorf("no final message received")
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unexpected end of stream")
|
||||
finalMessage.Content = make([]ContentBlock, len(contentBlocks))
|
||||
for _, v := range contentBlocks {
|
||||
finalMessage.Content = append(finalMessage.Content, *v)
|
||||
}
|
||||
|
||||
return convertResponseToMessage(*finalMessage)
|
||||
}
|
||||
|
||||
func convertResponseToMessage(resp ChatCompletionResponse) (*api.Message, error) {
|
||||
content := strings.Builder{}
|
||||
var toolCalls []api.ToolCall
|
||||
|
||||
for _, block := range resp.Content {
|
||||
switch block.Type {
|
||||
case "text":
|
||||
content.WriteString(block.Text)
|
||||
case "tool_use":
|
||||
parameters, ok := block.Input.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected type for tool call parameters: %T", block.Input)
|
||||
}
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
ID: block.ID,
|
||||
Name: block.Name,
|
||||
Parameters: parameters,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
message := &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: content.String(),
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
message.Role = api.MessageRoleToolCall
|
||||
}
|
||||
|
||||
return message, nil
|
||||
}
|
||||
|
@ -1,232 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
)
|
||||
|
||||
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
|
||||
|
||||
const TOOL_PREAMBLE = `You have access to the following tools when replying.
|
||||
|
||||
You may call them like this:
|
||||
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>$TOOL_NAME</tool_name>
|
||||
<parameters>
|
||||
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
|
||||
...
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
|
||||
Here are the tools available:`
|
||||
|
||||
const TOOL_PREAMBLE_FOOTER = `Recognize the utility of these tools in a broad range of different applications, and the power they give you to solve a wide range of different problems. However, ensure that the tools are used judiciously and only when clearly relevant to the user's request. Specifically:
|
||||
|
||||
1. Only use a tool if the user has explicitly requested or provided information that warrants its use. Do not make assumptions about files or data existing without the user mentioning them.
|
||||
|
||||
2. If there is ambiguity about whether using a tool is appropriate, ask a clarifying question to the user before proceeding. Confirm your understanding of their request and intent.
|
||||
|
||||
3. Prioritize providing direct responses and explanations based on your own knowledge and understanding. Use tools to supplement and enhance your responses when clearly applicable, but not as a default action.`
|
||||
|
||||
type XMLTools struct {
|
||||
XMLName struct{} `xml:"tools"`
|
||||
ToolDescriptions []XMLToolDescription `xml:"tool_description"`
|
||||
}
|
||||
|
||||
type XMLToolDescription struct {
|
||||
ToolName string `xml:"tool_name"`
|
||||
Description string `xml:"description"`
|
||||
Parameters []XMLToolParameter `xml:"parameters>parameter"`
|
||||
}
|
||||
|
||||
type XMLToolParameter struct {
|
||||
Name string `xml:"name"`
|
||||
Type string `xml:"type"`
|
||||
Description string `xml:"description"`
|
||||
}
|
||||
|
||||
type XMLFunctionCalls struct {
|
||||
XMLName struct{} `xml:"function_calls"`
|
||||
Invoke []XMLFunctionInvoke `xml:"invoke"`
|
||||
}
|
||||
|
||||
type XMLFunctionInvoke struct {
|
||||
ToolName string `xml:"tool_name"`
|
||||
Parameters XMLFunctionInvokeParameters `xml:"parameters"`
|
||||
}
|
||||
|
||||
type XMLFunctionInvokeParameters struct {
|
||||
String string `xml:",innerxml"`
|
||||
}
|
||||
|
||||
type XMLFunctionResults struct {
|
||||
XMLName struct{} `xml:"function_results"`
|
||||
Result []XMLFunctionResult `xml:"result"`
|
||||
}
|
||||
|
||||
type XMLFunctionResult struct {
|
||||
ToolName string `xml:"tool_name"`
|
||||
Stdout string `xml:"stdout"`
|
||||
}
|
||||
|
||||
// accepts raw XML from XMLFunctionInvokeParameters.String, returns map of
|
||||
// parameters name to value
|
||||
func parseFunctionParametersXML(params string) map[string]interface{} {
|
||||
lines := strings.Split(params, "\n")
|
||||
ret := make(map[string]interface{}, len(lines))
|
||||
for _, line := range lines {
|
||||
i := strings.Index(line, ">")
|
||||
if i == -1 {
|
||||
continue
|
||||
}
|
||||
j := strings.Index(line, "</")
|
||||
if j == -1 {
|
||||
continue
|
||||
}
|
||||
// chop from after opening < to first > to get parameter name,
|
||||
// then chop after > to first </ to get parameter value
|
||||
ret[line[1:i]] = line[i+1 : j]
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func convertToolsToXMLTools(tools []api.ToolSpec) XMLTools {
|
||||
converted := make([]XMLToolDescription, len(tools))
|
||||
for i, tool := range tools {
|
||||
converted[i].ToolName = tool.Name
|
||||
converted[i].Description = tool.Description
|
||||
|
||||
params := make([]XMLToolParameter, len(tool.Parameters))
|
||||
for j, param := range tool.Parameters {
|
||||
params[j].Name = param.Name
|
||||
params[j].Description = param.Description
|
||||
params[j].Type = param.Type
|
||||
}
|
||||
|
||||
converted[i].Parameters = params
|
||||
}
|
||||
return XMLTools{
|
||||
ToolDescriptions: converted,
|
||||
}
|
||||
}
|
||||
|
||||
func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []api.ToolCall {
|
||||
toolCalls := make([]api.ToolCall, len(functionCalls.Invoke))
|
||||
for i, invoke := range functionCalls.Invoke {
|
||||
toolCalls[i].Name = invoke.ToolName
|
||||
toolCalls[i].Parameters = parseFunctionParametersXML(invoke.Parameters.String)
|
||||
}
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
func convertToolCallsToXMLFunctionCalls(toolCalls []api.ToolCall) XMLFunctionCalls {
|
||||
converted := make([]XMLFunctionInvoke, len(toolCalls))
|
||||
for i, toolCall := range toolCalls {
|
||||
var params XMLFunctionInvokeParameters
|
||||
var paramXML string
|
||||
for key, value := range toolCall.Parameters {
|
||||
paramXML += fmt.Sprintf("<%s>%v</%s>\n", key, value, key)
|
||||
}
|
||||
params.String = paramXML
|
||||
converted[i] = XMLFunctionInvoke{
|
||||
ToolName: toolCall.Name,
|
||||
Parameters: params,
|
||||
}
|
||||
}
|
||||
return XMLFunctionCalls{
|
||||
Invoke: converted,
|
||||
}
|
||||
}
|
||||
|
||||
func convertToolResultsToXMLFunctionResult(toolResults []api.ToolResult) XMLFunctionResults {
|
||||
converted := make([]XMLFunctionResult, len(toolResults))
|
||||
for i, result := range toolResults {
|
||||
converted[i].ToolName = result.ToolName
|
||||
converted[i].Stdout = result.Result
|
||||
}
|
||||
return XMLFunctionResults{
|
||||
Result: converted,
|
||||
}
|
||||
}
|
||||
|
||||
func buildToolsSystemPrompt(tools []api.ToolSpec) string {
|
||||
xmlTools := convertToolsToXMLTools(tools)
|
||||
xmlToolsString, err := xmlTools.XMLString()
|
||||
if err != nil {
|
||||
panic("Could not serialize []api.Tool to XMLTools")
|
||||
}
|
||||
return TOOL_PREAMBLE + "\n\n" + xmlToolsString + "\n\n" + TOOL_PREAMBLE_FOOTER
|
||||
}
|
||||
|
||||
func (x XMLTools) XMLString() (string, error) {
|
||||
tmpl, err := template.New("tools").Parse(`<tools>
|
||||
{{range .ToolDescriptions}}<tool_description>
|
||||
<tool_name>{{.ToolName}}</tool_name>
|
||||
<description>
|
||||
{{.Description}}
|
||||
</description>
|
||||
<parameters>
|
||||
{{range .Parameters}}<parameter>
|
||||
<name>{{.Name}}</name>
|
||||
<type>{{.Type}}</type>
|
||||
<description>{{.Description}}</description>
|
||||
</parameter>
|
||||
{{end}}</parameters>
|
||||
</tool_description>
|
||||
{{end}}</tools>`)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := tmpl.Execute(&buf, x); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func (x XMLFunctionResults) XMLString() (string, error) {
|
||||
tmpl, err := template.New("function_results").Parse(`<function_results>
|
||||
{{range .Result}}<result>
|
||||
<tool_name>{{.ToolName}}</tool_name>
|
||||
<stdout>{{.Stdout}}</stdout>
|
||||
</result>
|
||||
{{end}}</function_results>`)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := tmpl.Execute(&buf, x); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func (x XMLFunctionCalls) XMLString() (string, error) {
|
||||
tmpl, err := template.New("function_calls").Parse(`<function_calls>
|
||||
{{range .Invoke}}<invoke>
|
||||
<tool_name>{{.ToolName}}</tool_name>
|
||||
<parameters>{{.Parameters.String}}</parameters>
|
||||
</invoke>
|
||||
{{end}}</function_calls>`)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := tmpl.Execute(&buf, x); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
@ -1,38 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
type AnthropicClient struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
System string `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
//TopP float32 `json:"top_p,omitempty"`
|
||||
//TopK float32 `json:"top_k,omitempty"`
|
||||
}
|
||||
|
||||
type OriginalContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Id string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role"`
|
||||
Content []OriginalContent `json:"content"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
StopSequence string `json:"stop_sequence"`
|
||||
}
|
||||
|
@ -96,7 +96,7 @@ func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProv
|
||||
if m == model {
|
||||
switch *p.Kind {
|
||||
case "anthropic":
|
||||
url := "https://api.anthropic.com/v1"
|
||||
url := "https://api.anthropic.com"
|
||||
if p.BaseURL != nil {
|
||||
url = *p.BaseURL
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user