lmcli/pkg/api/provider/anthropic/anthropic.go

451 lines
12 KiB
Go
Raw Normal View History

package anthropic
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
)
// ANTHROPIC_VERSION is the API version string sent in the
// "anthropic-version" header of every request made by sendRequest.
const (
	ANTHROPIC_VERSION = "2023-06-01"
)
// AnthropicClient issues requests against the Messages endpoint
// (BaseURL + "/v1/messages").
type AnthropicClient struct {
	APIKey  string // sent verbatim in the x-api-key header
	BaseURL string // API root, without the /v1/messages suffix
}
// ChatCompletionMessage is one entry of the request's "messages" array.
// createChatCompletionRequest fills Content with either a plain string or a
// slice of content-block maps.
type ChatCompletionMessage struct {
	Role    string      `json:"role"`
	Content interface{} `json:"content"`
}
// Tool is one entry of the request's "tools" array, as produced by
// convertTools from an api.ToolSpec.
type Tool struct {
	Name        string      `json:"name"`
	Description string      `json:"description"`
	InputSchema InputSchema `json:"input_schema"`
}
// InputSchema is the schema-style description of a tool's arguments.
// convertTools always sets Type to "object" and keys Properties by parameter
// name.
//
// NOTE(review): a nil Required slice marshals as JSON null rather than [];
// confirm the API accepts that for tools with no required parameters.
type InputSchema struct {
	Type       string              `json:"type"`
	Properties map[string]Property `json:"properties"`
	Required   []string            `json:"required"`
}
// Property describes a single tool argument inside an InputSchema.
type Property struct {
	Type        string   `json:"type"`
	Description string   `json:"description"`
	Enum        []string `json:"enum,omitempty"` // allowed values; omitted when empty
}
// ChatCompletionRequest is the JSON body POSTed to the /v1/messages endpoint.
//
// createChatCompletionRequest lifts system-role messages out of Messages and
// into System. Stream is not set there; CreateChatCompletion and
// CreateChatCompletionStream set it themselves.
//
// NOTE(review): because of omitempty, a Temperature of exactly 0 is dropped
// from the payload, so the server-side default applies instead — confirm
// that is intended.
type ChatCompletionRequest struct {
	Model       string                  `json:"model"`
	Messages    []ChatCompletionMessage `json:"messages"`
	System      string                  `json:"system,omitempty"`
	Tools       []Tool                  `json:"tools,omitempty"`
	MaxTokens   int                     `json:"max_tokens"`
	Temperature float32                 `json:"temperature,omitempty"`
	Stream      bool                    `json:"stream"`
}
// ContentBlock is one element of a response's "content" array.
//
// A "text" block carries Text. A "tool_use" block carries ID, Name and Input.
// The streaming path accumulates tool input in partialJsonAccumulator and
// decodes it into Input when the block stops. Because that field is
// unexported, encoding/json never reads or writes it.
type ContentBlock struct {
	Type                   string      `json:"type"`
	Text                   string      `json:"text,omitempty"`
	ID                     string      `json:"id,omitempty"`
	Name                   string      `json:"name,omitempty"`
	Input                  interface{} `json:"input,omitempty"`
	partialJsonAccumulator string
}
// ChatCompletionResponse is the decoded response message.
//
// CreateChatCompletion decodes it straight from a non-streaming response
// body. CreateChatCompletionStream builds it from the message_start event,
// then fills Content and StopReason from later events.
type ChatCompletionResponse struct {
	ID         string         `json:"id"`
	Type       string         `json:"type"`
	Role       string         `json:"role"`
	Model      string         `json:"model"`
	Content    []ContentBlock `json:"content"`
	StopReason string         `json:"stop_reason"`
	Usage      Usage          `json:"usage"`
}
// Usage holds the input/output token counts reported in a response.
type Usage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
}
// StreamEvent is the generic envelope decoded from each "data: " line of the
// response stream.
//
// Message and Delta are left loosely typed. Where an event needs a richer
// payload, CreateChatCompletionStream decodes the raw line a second time.
// Index identifies the content block that content_block_* events refer to.
type StreamEvent struct {
	Type    string      `json:"type"`
	Message interface{} `json:"message,omitempty"`
	Index   int         `json:"index,omitempty"`
	Delta   interface{} `json:"delta,omitempty"`
}
// convertTools translates generic tool specifications into the Anthropic
// tool format.
//
// Each tool gets an "object" input schema whose properties are the tool's
// parameters, keyed by name. Parameters flagged Required are listed in the
// schema's required list, in declaration order. The result always has one
// entry per input tool, in the same order.
func convertTools(tools []api.ToolSpec) []Tool {
	converted := make([]Tool, 0, len(tools))
	for _, spec := range tools {
		props := make(map[string]Property, len(spec.Parameters))
		var requiredNames []string
		// One pass fills both the property map and the required list.
		for _, p := range spec.Parameters {
			props[p.Name] = Property{
				Type:        p.Type,
				Description: p.Description,
				Enum:        p.Enum,
			}
			if p.Required {
				requiredNames = append(requiredNames, p.Name)
			}
		}
		converted = append(converted, Tool{
			Name:        spec.Name,
			Description: spec.Description,
			InputSchema: InputSchema{
				Type:       "object",
				Properties: props,
				Required:   requiredNames,
			},
		})
	}
	return converted
}
// createChatCompletionRequest converts a conversation into a Messages API
// request body. It also returns the prefill string.
//
// Conversion rules:
//   - System-role messages are not sent as messages. Their content goes into
//     the request's System field; if several are present, the last one wins.
//   - Tool-call messages are sent with role "assistant" and a list of content
//     blocks: a "text" block for any non-empty Content, then one "tool_use"
//     block per tool call.
//   - Tool-result messages are sent with role "user" and one "tool_result"
//     block per result.
//   - Every other message is sent with its own role and plain-string content.
//
// Tools are attached only when params.ToolBag is non-empty. Stream is left
// for the caller to set.
//
// The returned prefill is the content of the last message when
// api.IsAssistantContinuation(messages) is true (presumably a partial
// assistant reply being continued). Otherwise it is "".
func createChatCompletionRequest(
	params api.RequestParameters,
	messages []api.Message,
) (string, ChatCompletionRequest) {
	var system string
	converted := make([]ChatCompletionMessage, 0, len(messages))

	for _, msg := range messages {
		role := string(msg.Role)
		var content interface{} = msg.Content

		switch msg.Role {
		case api.MessageRoleSystem:
			system = msg.Content
			continue
		case api.MessageRoleToolCall:
			// Start from a non-nil slice so an empty list still encodes as [].
			blocks := []map[string]interface{}{}
			if msg.Content != "" {
				blocks = append(blocks, map[string]interface{}{
					"type": "text",
					"text": msg.Content,
				})
			}
			for _, call := range msg.ToolCalls {
				blocks = append(blocks, map[string]interface{}{
					"type":  "tool_use",
					"id":    call.ID,
					"name":  call.Name,
					"input": call.Parameters,
				})
			}
			role, content = "assistant", blocks
		case api.MessageRoleToolResult:
			blocks := []map[string]interface{}{}
			for _, result := range msg.ToolResults {
				blocks = append(blocks, map[string]interface{}{
					"type":        "tool_result",
					"tool_use_id": result.ToolCallID,
					"content":     result.Result,
				})
			}
			role, content = "user", blocks
		}

		converted = append(converted, ChatCompletionMessage{Role: role, Content: content})
	}

	body := ChatCompletionRequest{
		Model:       params.Model,
		Messages:    converted,
		System:      system,
		MaxTokens:   params.MaxTokens,
		Temperature: params.Temperature,
	}
	if len(params.ToolBag) > 0 {
		body.Tools = convertTools(params.ToolBag)
	}

	prefill := ""
	if api.IsAssistantContinuation(messages) {
		prefill = messages[len(messages)-1].Content
	}
	return prefill, body
}
// sendRequest marshals r to JSON, POSTs it to BaseURL + "/v1/messages" with
// the API key, version and content-type headers, and returns the raw HTTP
// response.
//
// On success (HTTP 200) the caller owns resp and must close resp.Body.
// On any error the returned response is nil, and any body has already been
// closed. A non-200 status becomes an error that includes the status line
// and the response body text.
func (c *AnthropicClient) sendRequest(ctx context.Context, r ChatCompletionRequest) (*http.Response, error) {
	payload, err := json.Marshal(r)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.BaseURL+"/v1/messages", bytes.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("failed to create HTTP request: %w", err)
	}
	req.Header.Set("x-api-key", c.APIKey)
	req.Header.Set("anthropic-version", ANTHROPIC_VERSION)
	req.Header.Set("content-type", "application/json")

	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}

	if resp.StatusCode != http.StatusOK {
		// Callers only close the body after a nil error. If we returned the
		// response together with an error, the connection would leak, so
		// close it here.
		defer resp.Body.Close()
		body, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			return nil, fmt.Errorf("unexpected status %s (reading body: %v)", resp.Status, readErr)
		}
		return nil, fmt.Errorf("unexpected status %s: %s", resp.Status, body)
	}
	return resp, nil
}
// CreateChatCompletion requests a single, non-streamed completion for the
// conversation and converts the decoded reply into an api.Message.
//
// It returns an error if messages is empty, if the request fails, or if the
// response body cannot be decoded.
func (c *AnthropicClient) CreateChatCompletion(
	ctx context.Context,
	params api.RequestParameters,
	messages []api.Message,
) (*api.Message, error) {
	if len(messages) == 0 {
		return nil, fmt.Errorf("can't create completion from no messages")
	}

	// The prefill string is not used on the non-streaming path.
	_, body := createChatCompletionRequest(params, messages)
	body.Stream = false

	resp, err := c.sendRequest(ctx, body)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	var decoded ChatCompletionResponse
	if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}
	return convertResponseToMessage(decoded)
}
// CreateChatCompletionStream requests a streamed completion. It forwards each
// text fragment to output as it arrives and returns the fully assembled
// message once the stream ends.
//
// Only "data: " lines of the response are parsed.
//   - Text deltas are appended to their content block and sent to output;
//     each chunk reports TokenCount 1.
//   - Tool input arrives as input_json_delta fragments. These are accumulated
//     per block and decoded into the block's Input at content_block_stop.
//   - When there is no prefill, leading whitespace at the very start of the
//     generated text is dropped. Trailing whitespace is kept.
//
// The stream ends at a message_stop event or when the body is exhausted.
// Content blocks are assembled in index order.
//
// Errors are returned for:
//   - an empty message list;
//   - request, read or decode failures;
//   - events that reference an unknown content block;
//   - server "error" events;
//   - context cancellation, including while blocked sending to output.
func (c *AnthropicClient) CreateChatCompletionStream(
	ctx context.Context,
	params api.RequestParameters,
	messages []api.Message,
	output chan<- api.Chunk,
) (*api.Message, error) {
	if len(messages) == 0 {
		return nil, fmt.Errorf("can't create completion from no messages")
	}

	prefill, req := createChatCompletionRequest(params, messages)
	req.Stream = true

	resp, err := c.sendRequest(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	// Blocks are stored by their stream index. A slice (rather than a map)
	// keeps them in server order when the final message is assembled.
	var contentBlocks []*ContentBlock
	blockAt := func(idx int) *ContentBlock {
		if idx < 0 || idx >= len(contentBlocks) {
			return nil
		}
		return contentBlocks[idx]
	}

	var finalMessage *ChatCompletionResponse
	firstChunkReceived := false
	reader := bufio.NewReader(resp.Body)
	dataPrefix := []byte("data: ")

readLoop:
	for atEOF := false; !atEOF; {
		if err := ctx.Err(); err != nil {
			return nil, err
		}

		line, readErr := reader.ReadBytes('\n')
		switch {
		case readErr == io.EOF:
			// Still process a final line that lacked a trailing newline,
			// then let the loop condition stop us.
			atEOF = true
		case readErr != nil:
			return nil, fmt.Errorf("error reading stream: %w", readErr)
		}

		line = bytes.TrimSpace(line)
		if !bytes.HasPrefix(line, dataPrefix) {
			continue
		}
		line = bytes.TrimPrefix(line, dataPrefix)

		var event StreamEvent
		if err := json.Unmarshal(line, &event); err != nil {
			return nil, fmt.Errorf("failed to unmarshal stream event: %w", err)
		}

		switch event.Type {
		case "message_start":
			finalMessage = &ChatCompletionResponse{}
			envelope := struct {
				Message *ChatCompletionResponse `json:"message"`
			}{Message: finalMessage}
			if err := json.Unmarshal(line, &envelope); err != nil {
				return nil, fmt.Errorf("failed to unmarshal message_start: %w", err)
			}

		case "content_block_start":
			var start struct {
				Index        int          `json:"index"`
				ContentBlock ContentBlock `json:"content_block"`
			}
			if err := json.Unmarshal(line, &start); err != nil {
				return nil, fmt.Errorf("failed to unmarshal content_block_start: %w", err)
			}
			if start.Index < 0 {
				return nil, fmt.Errorf("received content_block_start with negative index: %d", start.Index)
			}
			for len(contentBlocks) <= start.Index {
				contentBlocks = append(contentBlocks, nil)
			}
			contentBlocks[start.Index] = &start.ContentBlock

		case "content_block_delta":
			block := blockAt(event.Index)
			if block == nil {
				return nil, fmt.Errorf("received delta for non-existent content block index: %d", event.Index)
			}
			delta, ok := event.Delta.(map[string]interface{})
			if !ok {
				return nil, fmt.Errorf("unexpected delta type: %T", event.Delta)
			}
			deltaType, ok := delta["type"].(string)
			if !ok {
				return nil, fmt.Errorf("delta missing type field")
			}
			switch deltaType {
			case "text_delta":
				if text, ok := delta["text"].(string); ok {
					if !firstChunkReceived {
						if prefill == "" {
							// With no prefill, drop leading (but not trailing)
							// whitespace at the start of the reply.
							text = strings.TrimLeft(text, " \t\r\n")
						}
						// Keep trimming until some real text has arrived.
						firstChunkReceived = text != ""
					}
					block.Text += text
					select {
					case output <- api.Chunk{Content: text, TokenCount: 1}:
					case <-ctx.Done():
						return nil, ctx.Err()
					}
				}
			case "input_json_delta":
				if block.Type != "tool_use" {
					return nil, fmt.Errorf("received input_json_delta for non-tool_use block")
				}
				if fragment, ok := delta["partial_json"].(string); ok {
					block.partialJsonAccumulator += fragment
				}
			}

		case "content_block_stop":
			block := blockAt(event.Index)
			if block == nil {
				return nil, fmt.Errorf("received stop for non-existent content block index: %d", event.Index)
			}
			if block.Type == "tool_use" && block.partialJsonAccumulator != "" {
				var input map[string]interface{}
				if err := json.Unmarshal([]byte(block.partialJsonAccumulator), &input); err != nil {
					return nil, fmt.Errorf("failed to unmarshal accumulated JSON for tool use: %w", err)
				}
				block.Input = input
			}

		case "message_delta":
			if finalMessage == nil {
				return nil, fmt.Errorf("received message_delta before message_start")
			}
			delta, ok := event.Delta.(map[string]interface{})
			if !ok {
				return nil, fmt.Errorf("unexpected delta type in message_delta: %T", event.Delta)
			}
			if stopReason, ok := delta["stop_reason"].(string); ok {
				finalMessage.StopReason = stopReason
			}

		case "message_stop":
			break readLoop

		case "error":
			// Error details arrive under "error", not "message". Fall back to
			// the raw event text if that shape is not present.
			var errEvent struct {
				Error struct {
					Type    string `json:"type"`
					Message string `json:"message"`
				} `json:"error"`
			}
			if err := json.Unmarshal(line, &errEvent); err == nil && errEvent.Error.Message != "" {
				return nil, fmt.Errorf("received error event: %s: %s", errEvent.Error.Type, errEvent.Error.Message)
			}
			return nil, fmt.Errorf("received error event: %s", line)

		default:
			// Ignore event types not handled above.
		}
	}

	if finalMessage == nil {
		return nil, fmt.Errorf("no final message received")
	}

	// Assemble blocks in index order, skipping indices that never started.
	finalMessage.Content = make([]ContentBlock, 0, len(contentBlocks))
	for _, block := range contentBlocks {
		if block != nil {
			finalMessage.Content = append(finalMessage.Content, *block)
		}
	}
	return convertResponseToMessage(*finalMessage)
}
// convertResponseToMessage flattens a Messages API response into an
// api.Message.
//
//   - "text" blocks are concatenated, in order, into Content.
//   - Each "tool_use" block becomes an api.ToolCall.
//   - Other block types are ignored.
//
// The role is MessageRoleToolCall when at least one tool call is present and
// MessageRoleAssistant otherwise. An error is returned if a tool_use block's
// Input is not a map[string]interface{}.
func convertResponseToMessage(resp ChatCompletionResponse) (*api.Message, error) {
	var text strings.Builder
	var calls []api.ToolCall

	for i := range resp.Content {
		block := &resp.Content[i]
		if block.Type == "text" {
			text.WriteString(block.Text)
			continue
		}
		if block.Type != "tool_use" {
			continue
		}
		args, isObject := block.Input.(map[string]interface{})
		if !isObject {
			return nil, fmt.Errorf("unexpected type for tool call parameters: %T", block.Input)
		}
		calls = append(calls, api.ToolCall{
			ID:         block.ID,
			Name:       block.Name,
			Parameters: args,
		})
	}

	msg := &api.Message{
		Role:      api.MessageRoleAssistant,
		Content:   text.String(),
		ToolCalls: calls,
	}
	if len(calls) > 0 {
		msg.Role = api.MessageRoleToolCall
	}
	return msg, nil
}