package anthropic

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"encoding/xml"
	"fmt"
	"net/http"
	"strings"

	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
)

// AnthropicClient sends chat requests to the Anthropic messages endpoint,
// authenticating with APIKey.
type AnthropicClient struct {
	APIKey string
}

// Message is one conversation turn in the wire format of a Request.
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// Request is the JSON body POSTed to the messages endpoint.
type Request struct {
	Model         string    `json:"model"`
	Messages      []Message `json:"messages"`
	System        string    `json:"system,omitempty"`
	MaxTokens     int       `json:"max_tokens,omitempty"`
	StopSequences []string  `json:"stop_sequences,omitempty"`
	Stream        bool      `json:"stream,omitempty"`
	Temperature   float32   `json:"temperature,omitempty"`
	//TopP          float32  `json:"top_p,omitempty"`
	//TopK          float32  `json:"top_k,omitempty"`
}

// OriginalContent is a single content block of a non-streaming Response.
type OriginalContent struct {
	Type string `json:"type"`
	Text string `json:"text"`
}

// Response is the JSON body returned for a non-streaming request.
type Response struct {
	Id      string            `json:"id"`
	Type    string            `json:"type"`
	Role    string            `json:"role"`
	Content []OriginalContent `json:"content"`
}

// FUNCTION_STOP_SEQUENCE is sent as a stop sequence with every request. When
// a streamed reply stops on it, the accumulated text is parsed for an XML
// function-call block.
//
// NOTE(review): the literal was empty in the reviewed copy of this file, which
// looks like markup stripping; it is restored here to the closing tag that the
// parsing code in CreateChatCompletionStream expects. Confirm it matches the
// XML element name used by XMLFunctionCalls.
const FUNCTION_STOP_SEQUENCE = "</function_calls>"

// buildRequest converts the provider-agnostic parameters and messages into an
// Anthropic Request.
//
// A leading system message replaces params.SystemPrompt as the system prompt
// and is not sent as a conversation turn. When params.ToolBag is non-empty the
// tool descriptions are appended to the system prompt. Tool-call and
// tool-result messages are serialized to XML and sent as assistant and user
// turns respectively.
//
// It panics if tool calls or tool results cannot be serialized to XML; the
// signature has no error return to report that through.
func buildRequest(params model.RequestParameters, messages []model.Message) Request {
	requestBody := Request{
		Model:       params.Model,
		Messages:    make([]Message, len(messages)),
		System:      params.SystemPrompt,
		MaxTokens:   params.MaxTokens,
		Temperature: params.Temperature,
		Stream:      false,
		StopSequences: []string{
			FUNCTION_STOP_SEQUENCE,
			"\n\nHuman:",
		},
	}

	startIdx := 0
	if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
		requestBody.System = messages[0].Content
		// Drop one slot so the output slice lines up with messages[1:].
		requestBody.Messages = requestBody.Messages[1:]
		startIdx = 1
	}

	if len(params.ToolBag) > 0 {
		if len(requestBody.System) > 0 {
			// add a divider between existing system prompt and tools
			requestBody.System += "\n\n---\n\n"
		}
		requestBody.System += buildToolsSystemPrompt(params.ToolBag)
	}

	for i, msg := range messages[startIdx:] {
		message := &requestBody.Messages[i]

		switch msg.Role {
		case model.MessageRoleToolCall:
			xmlFuncCalls := convertToolCallsToXMLFunctionCalls(msg.ToolCalls)
			xmlString, err := xmlFuncCalls.XMLString()
			if err != nil {
				panic("Could not serialize []ToolCall to XMLFunctionCall")
			}

			message.Role = "assistant"
			if msg.Content != "" {
				// Any text that preceded the calls goes first, then the XML.
				message.Content = msg.Content + "\n\n" + xmlString
			} else {
				message.Content = xmlString
			}
		case model.MessageRoleToolResult:
			xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
			xmlString, err := xmlFuncResults.XMLString()
			if err != nil {
				panic("Could not serialize []ToolResult to XMLFunctionResults")
			}

			message.Role = "user"
			message.Content = xmlString
		default:
			message.Role = string(msg.Role)
			message.Content = msg.Content
		}
	}

	return requestBody
}

// sendRequest marshals r to JSON and POSTs it to the messages endpoint with
// the client's API key. The caller owns the returned response and must close
// its body. The HTTP status is not inspected here.
func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) {
	url := "https://api.anthropic.com/v1/messages"

	jsonBody, err := json.Marshal(r)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request body: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("failed to create HTTP request: %w", err)
	}

	req.Header.Set("x-api-key", c.APIKey)
	req.Header.Set("anthropic-version", "2023-06-01")
	req.Header.Set("content-type", "application/json")

	// No client-level timeout: streamed replies can legitimately run long, so
	// cancellation is left to ctx.
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send HTTP request: %w", err)
	}

	return resp, nil
}

// apiStatusError builds an error for a non-200 response. It includes the error
// type and message from the response body when the body decodes as a JSON
// object with an "error" member, and falls back to the HTTP status line
// otherwise.
func apiStatusError(resp *http.Response) error {
	var envelope struct {
		Error struct {
			Type    string `json:"type"`
			Message string `json:"message"`
		} `json:"error"`
	}

	err := json.NewDecoder(resp.Body).Decode(&envelope)
	if err != nil || envelope.Error.Message == "" {
		return fmt.Errorf("anthropic API returned %s", resp.Status)
	}

	return fmt.Errorf("anthropic API returned %s: %s: %s",
		resp.Status, envelope.Error.Type, envelope.Error.Message)
}

// CreateChatCompletion sends a non-streaming request and returns the
// concatenated text of every content block in the reply. callback, when
// non-nil, is invoked once per content block with an assistant message.
//
// It returns an error when the request fails, the API answers with a non-200
// status, the body cannot be decoded, or a content block has a type other
// than "text".
func (c *AnthropicClient) CreateChatCompletion(
	ctx context.Context,
	params model.RequestParameters,
	messages []model.Message,
	callback provider.ReplyCallback,
) (string, error) {
	request := buildRequest(params, messages)

	resp, err := sendRequest(ctx, c, request)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	// An error body decodes into an empty Response, so without this check a
	// failed call would look like a successful empty reply.
	if resp.StatusCode != http.StatusOK {
		return "", apiStatusError(resp)
	}

	var response Response
	if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
		return "", fmt.Errorf("failed to decode response: %w", err)
	}

	sb := strings.Builder{}
	for _, content := range response.Content {
		if content.Type != "text" {
			return "", fmt.Errorf("unsupported message type: %s", content.Type)
		}

		sb.WriteString(content.Text)
		if callback != nil {
			callback(model.Message{
				Role:    model.MessageRoleAssistant,
				Content: content.Text,
			})
		}
	}

	return sb.String(), nil
}

// CreateChatCompletionStream sends a streaming request, forwards each text
// delta to output as it arrives, and returns the full reply text once the
// stream ends.
//
// If the reply stops on FUNCTION_STOP_SEQUENCE, the function-call XML is
// parsed out of the accumulated text, the calls are executed against
// params.ToolBag, callback (when non-nil) receives the tool-call and
// tool-result messages, and the method calls itself with those two messages
// appended to the conversation. Otherwise callback receives one assistant
// message holding the complete text.
//
// Sends on output block until the receiver takes the value.
func (c *AnthropicClient) CreateChatCompletionStream(
	ctx context.Context,
	params model.RequestParameters,
	messages []model.Message,
	callback provider.ReplyCallback,
	output chan<- string,
) (string, error) {
	// Opening tag that pairs with FUNCTION_STOP_SEQUENCE.
	// NOTE(review): restored together with the stop sequence — confirm.
	const openTag = "<function_calls>"

	request := buildRequest(params, messages)
	request.Stream = true

	resp, err := sendRequest(ctx, c, request)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return "", apiStatusError(resp)
	}

	scanner := bufio.NewScanner(resp.Body)
	sb := strings.Builder{}

	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if len(line) == 0 {
			continue
		}

		// A bare JSON object (no "data:" prefix) is not a stream event; it is
		// always reported as an error.
		if line[0] == '{' {
			var event map[string]interface{}
			if err := json.Unmarshal([]byte(line), &event); err != nil {
				return "", fmt.Errorf("failed to unmarshal event data '%s': %w", line, err)
			}
			eventType, ok := event["type"].(string)
			if !ok {
				return "", fmt.Errorf("invalid event: %s", line)
			}
			if eventType == "error" {
				return sb.String(), fmt.Errorf("an error occurred: %v", event["error"])
			}
			return sb.String(), fmt.Errorf("unknown event type: %s", eventType)
		}

		// Any other line that is not a data line (e.g. "event: ...") is skipped.
		if !strings.HasPrefix(line, "data:") {
			continue
		}

		data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
		var event map[string]interface{}
		if err := json.Unmarshal([]byte(data), &event); err != nil {
			return "", fmt.Errorf("failed to unmarshal event data: %w", err)
		}

		eventType, ok := event["type"].(string)
		if !ok {
			return "", fmt.Errorf("invalid event type")
		}

		switch eventType {
		case "message_start", "content_block_start", "content_block_stop":
			// nothing to do
		case "ping":
			// write an empty string to signal start of text
			output <- ""
		case "content_block_delta":
			delta, ok := event["delta"].(map[string]interface{})
			if !ok {
				return "", fmt.Errorf("invalid content block delta")
			}
			text, ok := delta["text"].(string)
			if !ok {
				return "", fmt.Errorf("invalid text delta")
			}
			sb.WriteString(text)
			output <- text
		case "message_delta":
			delta, ok := event["delta"].(map[string]interface{})
			if !ok {
				return "", fmt.Errorf("invalid message delta")
			}
			stopReason, _ := delta["stop_reason"].(string)
			stopSequence, _ := delta["stop_sequence"].(string)
			if stopReason != "stop_sequence" || stopSequence != FUNCTION_STOP_SEQUENCE {
				// Not a function call; keep reading until message_stop.
				continue
			}

			content := sb.String()
			start := strings.Index(content, openTag)
			if start == -1 {
				return content, fmt.Errorf("reached stop sequence but no opening tag found")
			}

			// The stop sequence itself is not part of the streamed text, so
			// re-attach it to get a complete XML fragment.
			funcCallXML := content[start:] + FUNCTION_STOP_SEQUENCE
			output <- FUNCTION_STOP_SEQUENCE

			// Parse only the XML fragment, not the prose that preceded it.
			var functionCalls XMLFunctionCalls
			if err := xml.Unmarshal([]byte(funcCallXML), &functionCalls); err != nil {
				return "", fmt.Errorf("failed to unmarshal function_calls: %w", err)
			}

			toolCall := model.Message{
				Role: model.MessageRoleToolCall,
				// xml stripped from content
				Content:   content[:start],
				ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
			}

			toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
			if err != nil {
				return "", err
			}

			toolReply := model.Message{
				Role:        model.MessageRoleToolResult,
				ToolResults: toolResults,
			}

			if callback != nil {
				callback(toolCall)
				callback(toolReply)
			}

			// Continue the conversation with the tool call and its results
			// appended. Build a fresh slice so the caller's backing array is
			// never written to.
			next := make([]model.Message, 0, len(messages)+2)
			next = append(next, messages...)
			next = append(next, toolCall, toolReply)
			return c.CreateChatCompletionStream(ctx, params, next, callback, output)
		case "message_stop":
			// A tool-call reply never reaches this point (it returns from the
			// message_delta case), so this is always a plain assistant reply.
			if callback != nil {
				callback(model.Message{
					Role:    model.MessageRoleAssistant,
					Content: sb.String(),
				})
			}
			return sb.String(), nil
		case "error":
			return sb.String(), fmt.Errorf("an error occurred: %v", event["error"])
		default:
			fmt.Printf("\nUnrecognized event: %s\n", data)
		}
	}

	if err := scanner.Err(); err != nil {
		return "", fmt.Errorf("failed to read response body: %w", err)
	}

	return "", fmt.Errorf("unexpected end of stream")
}