package anthropic import ( "bufio" "bytes" "context" "encoding/json" "encoding/xml" "fmt" "net/http" "strings" "git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" ) func buildRequest(params model.RequestParameters, messages []model.Message) Request { requestBody := Request{ Model: params.Model, Messages: make([]Message, len(messages)), MaxTokens: params.MaxTokens, Temperature: params.Temperature, Stream: false, StopSequences: []string{ FUNCTION_STOP_SEQUENCE, "\n\nHuman:", }, } startIdx := 0 if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem { requestBody.System = messages[0].Content requestBody.Messages = requestBody.Messages[1:] startIdx = 1 } if len(params.ToolBag) > 0 { if len(requestBody.System) > 0 { // add a divider between existing system prompt and tools requestBody.System += "\n\n---\n\n" } requestBody.System += buildToolsSystemPrompt(params.ToolBag) } for i, msg := range messages[startIdx:] { message := &requestBody.Messages[i] switch msg.Role { case model.MessageRoleToolCall: message.Role = "assistant" if msg.Content != "" { message.Content = msg.Content } xmlFuncCalls := convertToolCallsToXMLFunctionCalls(msg.ToolCalls) xmlString, err := xmlFuncCalls.XMLString() if err != nil { panic("Could not serialize []ToolCall to XMLFunctionCall") } if len(message.Content) > 0 { message.Content += fmt.Sprintf("\n\n%s", xmlString) } else { message.Content = xmlString } case model.MessageRoleToolResult: xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults) xmlString, err := xmlFuncResults.XMLString() if err != nil { panic("Could not serialize []ToolResult to XMLFunctionResults") } message.Role = "user" message.Content = xmlString default: message.Role = string(msg.Role) message.Content = msg.Content } } return requestBody } func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) { jsonBody, err := json.Marshal(r) if err != nil { return nil, fmt.Errorf("failed to marshal request body: %v", err) } req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/messages", bytes.NewBuffer(jsonBody)) if err != nil { return nil, fmt.Errorf("failed to create HTTP request: %v", err) } req.Header.Set("x-api-key", c.APIKey) req.Header.Set("anthropic-version", "2023-06-01") req.Header.Set("content-type", "application/json") client := &http.Client{} resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("failed to send HTTP request: %v", err) } return resp, nil } func (c *AnthropicClient) CreateChatCompletion( ctx context.Context, params model.RequestParameters, messages []model.Message, callback api.ReplyCallback, ) (string, error) { if len(messages) == 0 { return "", fmt.Errorf("Can't create completion from no messages") } request := buildRequest(params, messages) resp, err := sendRequest(ctx, c, request) if err != nil { return "", err } defer resp.Body.Close() var response Response err = json.NewDecoder(resp.Body).Decode(&response) if err != nil { return "", fmt.Errorf("failed to decode response: %v", err) } sb := strings.Builder{} lastMessage := messages[len(messages)-1] if lastMessage.Role.IsAssistant() { // this is a continuation of a previous assistant reply, so we'll // include its contents in the final result sb.WriteString(lastMessage.Content) } for _, content := range response.Content { var reply model.Message switch content.Type { case "text": reply = model.Message{ Role: model.MessageRoleAssistant, Content: content.Text, } sb.WriteString(reply.Content) default: return "", fmt.Errorf("unsupported message type: %s", content.Type) } if callback != nil { callback(reply) } } return sb.String(), nil } func (c *AnthropicClient) CreateChatCompletionStream( ctx context.Context, params model.RequestParameters, messages []model.Message, callback api.ReplyCallback, output chan<- api.Chunk, ) (string, error) { if len(messages) == 0 { return "", fmt.Errorf("Can't create completion from no messages") } request := buildRequest(params, messages) request.Stream = true resp, err := sendRequest(ctx, c, request) if err != nil { return "", err } defer resp.Body.Close() sb := strings.Builder{} lastMessage := messages[len(messages)-1] continuation := false if messages[len(messages)-1].Role.IsAssistant() { // this is a continuation of a previous assistant reply, so we'll // include its contents in the final result sb.WriteString(lastMessage.Content) continuation = true } scanner := bufio.NewScanner(resp.Body) for scanner.Scan() { line := scanner.Text() line = strings.TrimSpace(line) if len(line) == 0 { continue } if line[0] == '{' { var event map[string]interface{} err := json.Unmarshal([]byte(line), &event) if err != nil { return "", fmt.Errorf("failed to unmarshal event data '%s': %v", line, err) } eventType, ok := event["type"].(string) if !ok { return "", fmt.Errorf("invalid event: %s", line) } switch eventType { case "error": return sb.String(), fmt.Errorf("an error occurred: %s", event["error"]) default: return sb.String(), fmt.Errorf("unknown event type: %s", eventType) } } else if strings.HasPrefix(line, "data:") { data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) var event map[string]interface{} err := json.Unmarshal([]byte(data), &event) if err != nil { return "", fmt.Errorf("failed to unmarshal event data: %v", err) } eventType, ok := event["type"].(string) if !ok { return "", fmt.Errorf("invalid event type") } switch eventType { case "message_start": // noop case "ping": // signals start of text - currently ignoring case "content_block_start": // ignore? case "content_block_delta": delta, ok := event["delta"].(map[string]interface{}) if !ok { return "", fmt.Errorf("invalid content block delta") } text, ok := delta["text"].(string) if !ok { return "", fmt.Errorf("invalid text delta") } sb.WriteString(text) output <- api.Chunk{ Content: text, } case "content_block_stop": // ignore? case "message_delta": delta, ok := event["delta"].(map[string]interface{}) if !ok { return "", fmt.Errorf("invalid message delta") } stopReason, ok := delta["stop_reason"].(string) if ok && stopReason == "stop_sequence" { stopSequence, ok := delta["stop_sequence"].(string) if ok && stopSequence == FUNCTION_STOP_SEQUENCE { content := sb.String() start := strings.Index(content, "") if start == -1 { return content, fmt.Errorf("reached stop sequence but no opening tag found") } sb.WriteString(FUNCTION_STOP_SEQUENCE) output <- api.Chunk{ Content: FUNCTION_STOP_SEQUENCE, } funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE var functionCalls XMLFunctionCalls err := xml.Unmarshal([]byte(funcCallXml), &functionCalls) if err != nil { return "", fmt.Errorf("failed to unmarshal function_calls: %v", err) } toolCall := model.Message{ Role: model.MessageRoleToolCall, // function call xml stripped from content for model interop Content: strings.TrimSpace(content[:start]), ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls), } toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag) if err != nil { return "", err } toolResult := model.Message{ Role: model.MessageRoleToolResult, ToolResults: toolResults, } if callback != nil { callback(toolCall) callback(toolResult) } if continuation { messages[len(messages)-1] = toolCall } else { messages = append(messages, toolCall) } messages = append(messages, toolResult) return c.CreateChatCompletionStream(ctx, params, messages, callback, output) } } case "message_stop": // return the completed message content := sb.String() if callback != nil { callback(model.Message{ Role: model.MessageRoleAssistant, Content: content, }) } return content, nil case "error": return sb.String(), fmt.Errorf("an error occurred: %s", event["error"]) default: fmt.Printf("\nUnrecognized event: %s\n", data) } } } if err := scanner.Err(); err != nil { return "", fmt.Errorf("failed to read response body: %v", err) } return "", fmt.Errorf("unexpected end of stream") }