package anthropic import ( "bufio" "bytes" "context" "encoding/json" "fmt" "io" "net/http" "strings" "git.mlow.ca/mlow/lmcli/pkg/api" ) const ANTHROPIC_VERSION = "2023-06-01" type AnthropicClient struct { APIKey string BaseURL string } type ChatCompletionMessage struct { Role string `json:"role"` Content interface{} `json:"content"` } type Tool struct { Name string `json:"name"` Description string `json:"description"` InputSchema InputSchema `json:"input_schema"` } type InputSchema struct { Type string `json:"type"` Properties map[string]Property `json:"properties"` Required []string `json:"required"` } type Property struct { Type string `json:"type"` Description string `json:"description"` Enum []string `json:"enum,omitempty"` } type ChatCompletionRequest struct { Model string `json:"model"` Messages []ChatCompletionMessage `json:"messages"` System string `json:"system,omitempty"` Tools []Tool `json:"tools,omitempty"` MaxTokens int `json:"max_tokens"` Temperature float32 `json:"temperature,omitempty"` Stream bool `json:"stream"` } type ContentBlock struct { Type string `json:"type"` Text string `json:"text,omitempty"` ID string `json:"id,omitempty"` Name string `json:"name,omitempty"` Input interface{} `json:"input,omitempty"` partialJsonAccumulator string } type ChatCompletionResponse struct { ID string `json:"id"` Type string `json:"type"` Role string `json:"role"` Model string `json:"model"` Content []ContentBlock `json:"content"` StopReason string `json:"stop_reason"` Usage Usage `json:"usage"` } type Usage struct { InputTokens int `json:"input_tokens"` OutputTokens int `json:"output_tokens"` } type StreamEvent struct { Type string `json:"type"` Message interface{} `json:"message,omitempty"` Index int `json:"index,omitempty"` Delta interface{} `json:"delta,omitempty"` } func convertTools(tools []api.ToolSpec) []Tool { anthropicTools := make([]Tool, len(tools)) for i, tool := range tools { properties := make(map[string]Property) for _, param := range tool.Parameters { properties[param.Name] = Property{ Type: param.Type, Description: param.Description, Enum: param.Enum, } } var required []string for _, param := range tool.Parameters { if param.Required { required = append(required, param.Name) } } anthropicTools[i] = Tool{ Name: tool.Name, Description: tool.Description, InputSchema: InputSchema{ Type: "object", Properties: properties, Required: required, }, } } return anthropicTools } func createChatCompletionRequest( params api.RequestParameters, messages []api.Message, ) (string, ChatCompletionRequest) { requestMessages := make([]ChatCompletionMessage, 0, len(messages)) var systemMessage string for _, m := range messages { if m.Role == api.MessageRoleSystem { systemMessage = m.Content continue } var content interface{} role := string(m.Role) switch m.Role { case api.MessageRoleToolCall: role = "assistant" contentBlocks := make([]map[string]interface{}, 0) if m.Content != "" { contentBlocks = append(contentBlocks, map[string]interface{}{ "type": "text", "text": m.Content, }) } for _, toolCall := range m.ToolCalls { contentBlocks = append(contentBlocks, map[string]interface{}{ "type": "tool_use", "id": toolCall.ID, "name": toolCall.Name, "input": toolCall.Parameters, }) } content = contentBlocks case api.MessageRoleToolResult: role = "user" contentBlocks := make([]map[string]interface{}, 0) for _, result := range m.ToolResults { contentBlock := map[string]interface{}{ "type": "tool_result", "tool_use_id": result.ToolCallID, "content": result.Result, } contentBlocks = append(contentBlocks, contentBlock) } content = contentBlocks default: content = m.Content } requestMessages = append(requestMessages, ChatCompletionMessage{ Role: role, Content: content, }) } request := ChatCompletionRequest{ Model: params.Model, Messages: requestMessages, System: systemMessage, MaxTokens: params.MaxTokens, Temperature: params.Temperature, } if len(params.Toolbox) > 0 { request.Tools = convertTools(params.Toolbox) } var prefill string if api.IsAssistantContinuation(messages) { prefill = messages[len(messages)-1].Content } return prefill, request } func (c *AnthropicClient) sendRequest(ctx context.Context, r ChatCompletionRequest) (*http.Response, error) { jsonData, err := json.Marshal(r) if err != nil { return nil, fmt.Errorf("failed to marshal request: %w", err) } req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/v1/messages", bytes.NewBuffer(jsonData)) if err != nil { return nil, fmt.Errorf("failed to create HTTP request: %w", err) } req.Header.Set("x-api-key", c.APIKey) req.Header.Set("anthropic-version", ANTHROPIC_VERSION) req.Header.Set("content-type", "application/json") client := &http.Client{} resp, err := client.Do(req) if err != nil { return nil, err } if resp.StatusCode != 200 { bytes, _ := io.ReadAll(resp.Body) return resp, fmt.Errorf("%v", string(bytes)) } return resp, err } func (c *AnthropicClient) CreateChatCompletion( ctx context.Context, params api.RequestParameters, messages []api.Message, ) (*api.Message, error) { if len(messages) == 0 { return nil, fmt.Errorf("can't create completion from no messages") } _, req := createChatCompletionRequest(params, messages) req.Stream = false resp, err := c.sendRequest(ctx, req) if err != nil { return nil, fmt.Errorf("failed to send request: %w", err) } defer resp.Body.Close() var completionResp ChatCompletionResponse err = json.NewDecoder(resp.Body).Decode(&completionResp) if err != nil { return nil, fmt.Errorf("failed to decode response: %w", err) } return convertResponseToMessage(completionResp) } func (c *AnthropicClient) CreateChatCompletionStream( ctx context.Context, params api.RequestParameters, messages []api.Message, output chan<- api.Chunk, ) (*api.Message, error) { if len(messages) == 0 { return nil, fmt.Errorf("can't create completion from no messages") } prefill, req := createChatCompletionRequest(params, messages) req.Stream = true resp, err := c.sendRequest(ctx, req) if err != nil { return nil, fmt.Errorf("failed to send request: %w", err) } defer resp.Body.Close() contentBlocks := make(map[int]*ContentBlock) var finalMessage *ChatCompletionResponse var firstChunkReceived bool reader := bufio.NewReader(resp.Body) for { select { case <-ctx.Done(): return nil, ctx.Err() default: line, err := reader.ReadBytes('\n') if err != nil { if err == io.EOF { break } return nil, fmt.Errorf("error reading stream: %w", err) } line = bytes.TrimSpace(line) if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) { continue } line = bytes.TrimPrefix(line, []byte("data: ")) var streamEvent StreamEvent err = json.Unmarshal(line, &streamEvent) if err != nil { return nil, fmt.Errorf("failed to unmarshal stream event: %w", err) } switch streamEvent.Type { case "message_start": finalMessage = &ChatCompletionResponse{} err = json.Unmarshal(line, &struct { Message *ChatCompletionResponse `json:"message"` }{Message: finalMessage}) if err != nil { return nil, fmt.Errorf("failed to unmarshal message_start: %w", err) } case "content_block_start": var contentBlockStart struct { Index int `json:"index"` ContentBlock ContentBlock `json:"content_block"` } err = json.Unmarshal(line, &contentBlockStart) if err != nil { return nil, fmt.Errorf("failed to unmarshal content_block_start: %w", err) } contentBlocks[contentBlockStart.Index] = &contentBlockStart.ContentBlock case "content_block_delta": if streamEvent.Index >= len(contentBlocks) { return nil, fmt.Errorf("received delta for non-existent content block index: %d", streamEvent.Index) } block := contentBlocks[streamEvent.Index] delta, ok := streamEvent.Delta.(map[string]interface{}) if !ok { return nil, fmt.Errorf("unexpected delta type: %T", streamEvent.Delta) } deltaType, ok := delta["type"].(string) if !ok { return nil, fmt.Errorf("delta missing type field") } switch deltaType { case "text_delta": if text, ok := delta["text"].(string); ok { if !firstChunkReceived { if prefill == "" { // if there is no prefil, ensure we trim leading whitespace text = strings.TrimSpace(text) } firstChunkReceived = true } block.Text += text output <- api.Chunk{ Content: text, TokenCount: 1, } } case "input_json_delta": if block.Type != "tool_use" { return nil, fmt.Errorf("received input_json_delta for non-tool_use block") } if partialJSON, ok := delta["partial_json"].(string); ok { block.partialJsonAccumulator += partialJSON } } case "content_block_stop": if streamEvent.Index >= len(contentBlocks) { return nil, fmt.Errorf("received stop for non-existent content block index: %d", streamEvent.Index) } block := contentBlocks[streamEvent.Index] if block.Type == "tool_use" && block.partialJsonAccumulator != "" { var inputData map[string]interface{} err := json.Unmarshal([]byte(block.partialJsonAccumulator), &inputData) if err != nil { return nil, fmt.Errorf("failed to unmarshal accumulated JSON for tool use: %w", err) } block.Input = inputData } case "message_delta": if finalMessage == nil { return nil, fmt.Errorf("received message_delta before message_start") } delta, ok := streamEvent.Delta.(map[string]interface{}) if !ok { return nil, fmt.Errorf("unexpected delta type in message_delta: %T", streamEvent.Delta) } if stopReason, ok := delta["stop_reason"].(string); ok { finalMessage.StopReason = stopReason } case "message_stop": // End of the stream goto END_STREAM case "error": return nil, fmt.Errorf("received error event: %v", streamEvent.Message) default: // Ignore unknown event types } } } END_STREAM: if finalMessage == nil { return nil, fmt.Errorf("no final message received") } finalMessage.Content = make([]ContentBlock, len(contentBlocks)) for _, v := range contentBlocks { finalMessage.Content = append(finalMessage.Content, *v) } return convertResponseToMessage(*finalMessage) } func convertResponseToMessage(resp ChatCompletionResponse) (*api.Message, error) { content := strings.Builder{} var toolCalls []api.ToolCall for _, block := range resp.Content { switch block.Type { case "text": content.WriteString(block.Text) case "tool_use": parameters, ok := block.Input.(map[string]interface{}) if !ok { return nil, fmt.Errorf("unexpected type for tool call parameters: %T", block.Input) } toolCalls = append(toolCalls, api.ToolCall{ ID: block.ID, Name: block.Name, Parameters: parameters, }) } } message := &api.Message{ Role: api.MessageRoleAssistant, Content: content.String(), ToolCalls: toolCalls, } if len(toolCalls) > 0 { message.Role = api.MessageRoleToolCall } return message, nil }