454 lines
12 KiB
Go
454 lines
12 KiB
Go
package anthropic
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
|
)
|
|
|
|
const (
	// ANTHROPIC_VERSION is the API version sent in the anthropic-version
	// header of every request.
	ANTHROPIC_VERSION = "2023-06-01"
)

// AnthropicClient sends chat completion requests to the Anthropic Messages
// API.
type AnthropicClient struct {
	// APIKey is sent in the x-api-key header.
	APIKey string

	// BaseURL is the API root. Endpoint paths such as "/v1/messages" are
	// appended to it verbatim, so it should not end with a slash.
	BaseURL string
}
|
|
|
|
// Request-side wire types for the Messages API. Field order determines the
// order of keys in the marshalled JSON.
type (
	// ChatCompletionMessage is one entry of a request's "messages" list.
	// Content holds either a plain string or a slice of content-block
	// objects (see createChatCompletionRequest).
	ChatCompletionMessage struct {
		Role    string      `json:"role"`
		Content interface{} `json:"content"`
	}

	// Tool describes a tool the model may call, together with the JSON
	// Schema of its input.
	Tool struct {
		Name        string      `json:"name"`
		Description string      `json:"description"`
		InputSchema InputSchema `json:"input_schema"`
	}

	// InputSchema is the JSON Schema of a tool's input object.
	InputSchema struct {
		Type       string              `json:"type"`
		Properties map[string]Property `json:"properties"`
		Required   []string            `json:"required"`
	}

	// Property describes one tool parameter. Enum is left out of the JSON
	// when it is empty.
	Property struct {
		Type        string   `json:"type"`
		Description string   `json:"description"`
		Enum        []string `json:"enum,omitempty"`
	}

	// ChatCompletionRequest is the JSON body POSTed to the Messages
	// endpoint. System and Tools are left out of the JSON when empty.
	//
	// NOTE(review): because of omitempty, a Temperature of exactly 0 is not
	// sent at all, so the API's default temperature applies instead —
	// confirm callers never need an explicit zero.
	ChatCompletionRequest struct {
		Model       string                  `json:"model"`
		Messages    []ChatCompletionMessage `json:"messages"`
		System      string                  `json:"system,omitempty"`
		Tools       []Tool                  `json:"tools,omitempty"`
		MaxTokens   int                     `json:"max_tokens"`
		Temperature float32                 `json:"temperature,omitempty"`
		Stream      bool                    `json:"stream"`
	}
)
|
|
|
|
// Response-side wire types for the Messages API, covering both complete
// responses and server-sent stream events.
type (
	// ContentBlock is one element of a response's "content" array: a text
	// block (Text set) or a tool_use block (ID, Name and Input set).
	ContentBlock struct {
		Type  string      `json:"type"`
		Text  string      `json:"text,omitempty"`
		ID    string      `json:"id,omitempty"`
		Name  string      `json:"name,omitempty"`
		Input interface{} `json:"input,omitempty"`

		// partialJsonAccumulator is unexported and therefore never
		// serialized. While streaming it collects the input_json_delta
		// fragments of a tool_use block until the block stops.
		partialJsonAccumulator string
	}

	// ChatCompletionResponse is a complete Messages API response. It is also
	// the payload of the streaming message_start event.
	ChatCompletionResponse struct {
		ID         string         `json:"id"`
		Type       string         `json:"type"`
		Role       string         `json:"role"`
		Model      string         `json:"model"`
		Content    []ContentBlock `json:"content"`
		StopReason string         `json:"stop_reason"`
		Usage      Usage          `json:"usage"`
	}

	// Usage holds the token counts reported with a response.
	Usage struct {
		InputTokens  int `json:"input_tokens"`
		OutputTokens int `json:"output_tokens"`
	}

	// StreamEvent is the generic envelope of one streamed "data:" payload.
	// Message and Delta are decoded into untyped values; JSON objects become
	// map[string]interface{}.
	StreamEvent struct {
		Type    string      `json:"type"`
		Message interface{} `json:"message,omitempty"`
		Index   int         `json:"index,omitempty"`
		Delta   interface{} `json:"delta,omitempty"`
	}
)
|
|
|
|
func convertTools(tools []api.ToolSpec) []Tool {
|
|
anthropicTools := make([]Tool, len(tools))
|
|
for i, tool := range tools {
|
|
properties := make(map[string]Property)
|
|
for _, param := range tool.Parameters {
|
|
properties[param.Name] = Property{
|
|
Type: param.Type,
|
|
Description: param.Description,
|
|
Enum: param.Enum,
|
|
}
|
|
}
|
|
|
|
var required []string
|
|
for _, param := range tool.Parameters {
|
|
if param.Required {
|
|
required = append(required, param.Name)
|
|
}
|
|
}
|
|
|
|
anthropicTools[i] = Tool{
|
|
Name: tool.Name,
|
|
Description: tool.Description,
|
|
InputSchema: InputSchema{
|
|
Type: "object",
|
|
Properties: properties,
|
|
Required: required,
|
|
},
|
|
}
|
|
}
|
|
return anthropicTools
|
|
}
|
|
|
|
func createChatCompletionRequest(
|
|
params provider.RequestParameters,
|
|
messages []api.Message,
|
|
) (string, ChatCompletionRequest) {
|
|
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
|
|
var systemMessage string
|
|
|
|
for _, m := range messages {
|
|
if m.Role == api.MessageRoleSystem {
|
|
systemMessage = m.Content
|
|
continue
|
|
}
|
|
|
|
var content interface{}
|
|
role := string(m.Role)
|
|
|
|
switch m.Role {
|
|
case api.MessageRoleToolCall:
|
|
role = "assistant"
|
|
contentBlocks := make([]map[string]interface{}, 0)
|
|
if m.Content != "" {
|
|
contentBlocks = append(contentBlocks, map[string]interface{}{
|
|
"type": "text",
|
|
"text": m.Content,
|
|
})
|
|
}
|
|
for _, toolCall := range m.ToolCalls {
|
|
contentBlocks = append(contentBlocks, map[string]interface{}{
|
|
"type": "tool_use",
|
|
"id": toolCall.ID,
|
|
"name": toolCall.Name,
|
|
"input": toolCall.Parameters,
|
|
})
|
|
}
|
|
content = contentBlocks
|
|
|
|
case api.MessageRoleToolResult:
|
|
role = "user"
|
|
contentBlocks := make([]map[string]interface{}, 0)
|
|
for _, result := range m.ToolResults {
|
|
contentBlock := map[string]interface{}{
|
|
"type": "tool_result",
|
|
"tool_use_id": result.ToolCallID,
|
|
"content": result.Result,
|
|
}
|
|
contentBlocks = append(contentBlocks, contentBlock)
|
|
}
|
|
content = contentBlocks
|
|
|
|
default:
|
|
content = m.Content
|
|
}
|
|
|
|
requestMessages = append(requestMessages, ChatCompletionMessage{
|
|
Role: role,
|
|
Content: content,
|
|
})
|
|
}
|
|
|
|
request := ChatCompletionRequest{
|
|
Model: params.Model,
|
|
Messages: requestMessages,
|
|
System: systemMessage,
|
|
MaxTokens: params.MaxTokens,
|
|
Temperature: params.Temperature,
|
|
}
|
|
|
|
if len(params.Toolbox) > 0 {
|
|
request.Tools = convertTools(params.Toolbox)
|
|
}
|
|
|
|
var prefill string
|
|
if len(messages) > 0 && messages[len(messages)-1].Role == api.MessageRoleAssistant {
|
|
// Prompting on an assitant message, use its content as prefill
|
|
prefill = messages[len(messages)-1].Content
|
|
}
|
|
|
|
return prefill, request
|
|
}
|
|
|
|
func (c *AnthropicClient) sendRequest(ctx context.Context, r ChatCompletionRequest) (*http.Response, error) {
|
|
jsonData, err := json.Marshal(r)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/v1/messages", bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("x-api-key", c.APIKey)
|
|
req.Header.Set("anthropic-version", ANTHROPIC_VERSION)
|
|
req.Header.Set("content-type", "application/json")
|
|
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if resp.StatusCode != 200 {
|
|
bytes, _ := io.ReadAll(resp.Body)
|
|
return resp, fmt.Errorf("%v", string(bytes))
|
|
}
|
|
|
|
return resp, err
|
|
}
|
|
|
|
func (c *AnthropicClient) CreateChatCompletion(
|
|
ctx context.Context,
|
|
params provider.RequestParameters,
|
|
messages []api.Message,
|
|
) (*api.Message, error) {
|
|
if len(messages) == 0 {
|
|
return nil, fmt.Errorf("can't create completion from no messages")
|
|
}
|
|
|
|
_, req := createChatCompletionRequest(params, messages)
|
|
req.Stream = false
|
|
|
|
resp, err := c.sendRequest(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to send request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
var completionResp ChatCompletionResponse
|
|
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decode response: %w", err)
|
|
}
|
|
|
|
return convertResponseToMessage(completionResp)
|
|
}
|
|
|
|
func (c *AnthropicClient) CreateChatCompletionStream(
|
|
ctx context.Context,
|
|
params provider.RequestParameters,
|
|
messages []api.Message,
|
|
output chan<- provider.Chunk,
|
|
) (*api.Message, error) {
|
|
if len(messages) == 0 {
|
|
return nil, fmt.Errorf("can't create completion from no messages")
|
|
}
|
|
|
|
prefill, req := createChatCompletionRequest(params, messages)
|
|
req.Stream = true
|
|
|
|
resp, err := c.sendRequest(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to send request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
contentBlocks := make(map[int]*ContentBlock)
|
|
var finalMessage *ChatCompletionResponse
|
|
|
|
var firstChunkReceived bool
|
|
|
|
reader := bufio.NewReader(resp.Body)
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
default:
|
|
line, err := reader.ReadBytes('\n')
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
return nil, fmt.Errorf("error reading stream: %w", err)
|
|
}
|
|
|
|
line = bytes.TrimSpace(line)
|
|
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
|
|
continue
|
|
}
|
|
|
|
line = bytes.TrimPrefix(line, []byte("data: "))
|
|
|
|
var streamEvent StreamEvent
|
|
err = json.Unmarshal(line, &streamEvent)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal stream event: %w", err)
|
|
}
|
|
|
|
switch streamEvent.Type {
|
|
case "message_start":
|
|
finalMessage = &ChatCompletionResponse{}
|
|
err = json.Unmarshal(line, &struct {
|
|
Message *ChatCompletionResponse `json:"message"`
|
|
}{Message: finalMessage})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal message_start: %w", err)
|
|
}
|
|
case "content_block_start":
|
|
var contentBlockStart struct {
|
|
Index int `json:"index"`
|
|
ContentBlock ContentBlock `json:"content_block"`
|
|
}
|
|
err = json.Unmarshal(line, &contentBlockStart)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal content_block_start: %w", err)
|
|
}
|
|
|
|
contentBlocks[contentBlockStart.Index] = &contentBlockStart.ContentBlock
|
|
case "content_block_delta":
|
|
if streamEvent.Index >= len(contentBlocks) {
|
|
return nil, fmt.Errorf("received delta for non-existent content block index: %d", streamEvent.Index)
|
|
}
|
|
|
|
block := contentBlocks[streamEvent.Index]
|
|
delta, ok := streamEvent.Delta.(map[string]interface{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("unexpected delta type: %T", streamEvent.Delta)
|
|
}
|
|
|
|
deltaType, ok := delta["type"].(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("delta missing type field")
|
|
}
|
|
|
|
switch deltaType {
|
|
case "text_delta":
|
|
if text, ok := delta["text"].(string); ok {
|
|
if !firstChunkReceived {
|
|
if prefill == "" {
|
|
// if there is no prefil, ensure we trim leading whitespace
|
|
text = strings.TrimSpace(text)
|
|
}
|
|
firstChunkReceived = true
|
|
}
|
|
block.Text += text
|
|
output <- provider.Chunk{
|
|
Content: text,
|
|
// rough, anthropic performs some chunking
|
|
TokenCount: uint(len(strings.Split(text, " "))),
|
|
}
|
|
}
|
|
case "input_json_delta":
|
|
if block.Type != "tool_use" {
|
|
return nil, fmt.Errorf("received input_json_delta for non-tool_use block")
|
|
}
|
|
if partialJSON, ok := delta["partial_json"].(string); ok {
|
|
block.partialJsonAccumulator += partialJSON
|
|
}
|
|
}
|
|
|
|
case "content_block_stop":
|
|
if streamEvent.Index >= len(contentBlocks) {
|
|
return nil, fmt.Errorf("received stop for non-existent content block index: %d", streamEvent.Index)
|
|
}
|
|
|
|
block := contentBlocks[streamEvent.Index]
|
|
if block.Type == "tool_use" && block.partialJsonAccumulator != "" {
|
|
var inputData map[string]interface{}
|
|
err := json.Unmarshal([]byte(block.partialJsonAccumulator), &inputData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal accumulated JSON for tool use: %w", err)
|
|
}
|
|
block.Input = inputData
|
|
}
|
|
case "message_delta":
|
|
if finalMessage == nil {
|
|
return nil, fmt.Errorf("received message_delta before message_start")
|
|
}
|
|
delta, ok := streamEvent.Delta.(map[string]interface{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("unexpected delta type in message_delta: %T", streamEvent.Delta)
|
|
}
|
|
if stopReason, ok := delta["stop_reason"].(string); ok {
|
|
finalMessage.StopReason = stopReason
|
|
}
|
|
|
|
case "message_stop":
|
|
// End of the stream
|
|
goto END_STREAM
|
|
|
|
case "error":
|
|
return nil, fmt.Errorf("received error event: %v", streamEvent.Message)
|
|
|
|
default:
|
|
// Ignore unknown event types
|
|
}
|
|
}
|
|
}
|
|
|
|
END_STREAM:
|
|
if finalMessage == nil {
|
|
return nil, fmt.Errorf("no final message received")
|
|
}
|
|
|
|
finalMessage.Content = make([]ContentBlock, len(contentBlocks))
|
|
for _, v := range contentBlocks {
|
|
finalMessage.Content = append(finalMessage.Content, *v)
|
|
}
|
|
|
|
return convertResponseToMessage(*finalMessage)
|
|
}
|
|
|
|
func convertResponseToMessage(resp ChatCompletionResponse) (*api.Message, error) {
|
|
content := strings.Builder{}
|
|
var toolCalls []api.ToolCall
|
|
|
|
for _, block := range resp.Content {
|
|
switch block.Type {
|
|
case "text":
|
|
content.WriteString(block.Text)
|
|
case "tool_use":
|
|
parameters, ok := block.Input.(map[string]interface{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("unexpected type for tool call parameters: %T", block.Input)
|
|
}
|
|
toolCalls = append(toolCalls, api.ToolCall{
|
|
ID: block.ID,
|
|
Name: block.Name,
|
|
Parameters: parameters,
|
|
})
|
|
}
|
|
}
|
|
|
|
message := &api.Message{
|
|
Role: api.MessageRoleAssistant,
|
|
Content: content.String(),
|
|
ToolCalls: toolCalls,
|
|
}
|
|
|
|
if len(toolCalls) > 0 {
|
|
message.Role = api.MessageRoleToolCall
|
|
}
|
|
|
|
return message, nil
|
|
}
|