Matt Low
0384c7cb66
This refactor splits out all conversation concerns into a new `conversation` package. There is now a split between `conversation`'s and `api`'s representation of `Message`, the latter storing the minimum information required for interaction with LLM providers. There is necessary conversion between the two when making LLM calls.
448 lines
12 KiB
Go
448 lines
12 KiB
Go
package anthropic
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
|
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
|
)
|
|
|
|
// ANTHROPIC_VERSION is the value sent in the "anthropic-version" request
// header on every call made by this client.
const ANTHROPIC_VERSION = "2023-06-01"
|
|
|
|
// AnthropicClient talks to the Anthropic Messages API.
//
// APIKey is sent in the "x-api-key" header of each request. BaseURL is the
// API root; the "/v1/messages" path is appended to it when sending.
type AnthropicClient struct {
	APIKey  string
	BaseURL string
}
|
|
|
|
// ChatCompletionMessage is a single entry of the "messages" array in a
// request.
//
// Content is either a plain string or a slice of content-block maps (text,
// tool_use, tool_result), which is why it is left untyped.
type ChatCompletionMessage struct {
	Role    string      `json:"role"`
	Content interface{} `json:"content"`
}
|
|
|
|
type Tool struct {
|
|
Name string `json:"name"`
|
|
Description string `json:"description"`
|
|
InputSchema InputSchema `json:"input_schema"`
|
|
}
|
|
|
|
type InputSchema struct {
|
|
Type string `json:"type"`
|
|
Properties map[string]Property `json:"properties"`
|
|
Required []string `json:"required"`
|
|
}
|
|
|
|
// Property describes a single parameter inside an InputSchema: its type
// name, a description, and (optionally) the set of allowed values. Enum is
// left out of the encoded JSON when it is empty.
type Property struct {
	Type        string   `json:"type"`
	Description string   `json:"description"`
	Enum        []string `json:"enum,omitempty"`
}
|
|
|
|
type ChatCompletionRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []ChatCompletionMessage `json:"messages"`
|
|
System string `json:"system,omitempty"`
|
|
Tools []Tool `json:"tools,omitempty"`
|
|
MaxTokens int `json:"max_tokens"`
|
|
Temperature float32 `json:"temperature,omitempty"`
|
|
Stream bool `json:"stream"`
|
|
}
|
|
|
|
// ContentBlock is one element of a response's "content" array.
//
// Type distinguishes the block kinds handled in this file: "text" blocks
// carry Text, while "tool_use" blocks carry ID, Name and Input.
type ContentBlock struct {
	Type  string      `json:"type"`
	Text  string      `json:"text,omitempty"`
	ID    string      `json:"id,omitempty"`
	Name  string      `json:"name,omitempty"`
	Input interface{} `json:"input,omitempty"`

	// partialJsonAccumulator collects the partial_json fragments of a
	// streamed tool_use block until the block is complete. It is unexported
	// and so never encoded to or decoded from JSON.
	partialJsonAccumulator string
}
|
|
|
|
type ChatCompletionResponse struct {
|
|
ID string `json:"id"`
|
|
Type string `json:"type"`
|
|
Role string `json:"role"`
|
|
Model string `json:"model"`
|
|
Content []ContentBlock `json:"content"`
|
|
StopReason string `json:"stop_reason"`
|
|
Usage Usage `json:"usage"`
|
|
}
|
|
|
|
// Usage holds the input and output token counts reported in a response.
type Usage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
}
|
|
|
|
// StreamEvent is the generic envelope used for a first decoding pass over
// each "data:" line of a streamed response.
//
// Type selects how the event is handled. Message and Delta are left untyped
// because their shape depends on Type; Index identifies the content block
// that a block-level event refers to.
type StreamEvent struct {
	Type    string      `json:"type"`
	Message interface{} `json:"message,omitempty"`
	Index   int         `json:"index,omitempty"`
	Delta   interface{} `json:"delta,omitempty"`
}
|
|
|
|
func convertTools(tools []api.ToolSpec) []Tool {
|
|
anthropicTools := make([]Tool, len(tools))
|
|
for i, tool := range tools {
|
|
properties := make(map[string]Property)
|
|
for _, param := range tool.Parameters {
|
|
properties[param.Name] = Property{
|
|
Type: param.Type,
|
|
Description: param.Description,
|
|
Enum: param.Enum,
|
|
}
|
|
}
|
|
|
|
var required []string
|
|
for _, param := range tool.Parameters {
|
|
if param.Required {
|
|
required = append(required, param.Name)
|
|
}
|
|
}
|
|
|
|
anthropicTools[i] = Tool{
|
|
Name: tool.Name,
|
|
Description: tool.Description,
|
|
InputSchema: InputSchema{
|
|
Type: "object",
|
|
Properties: properties,
|
|
Required: required,
|
|
},
|
|
}
|
|
}
|
|
return anthropicTools
|
|
}
|
|
|
|
func createChatCompletionRequest(
|
|
params provider.RequestParameters,
|
|
messages []api.Message,
|
|
) (string, ChatCompletionRequest) {
|
|
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
|
|
var systemMessage string
|
|
|
|
for _, m := range messages {
|
|
if m.Role == api.MessageRoleSystem {
|
|
systemMessage = m.Content
|
|
continue
|
|
}
|
|
|
|
var content interface{}
|
|
role := string(m.Role)
|
|
|
|
switch m.Role {
|
|
case api.MessageRoleToolCall:
|
|
role = "assistant"
|
|
contentBlocks := make([]map[string]interface{}, 0)
|
|
if m.Content != "" {
|
|
contentBlocks = append(contentBlocks, map[string]interface{}{
|
|
"type": "text",
|
|
"text": m.Content,
|
|
})
|
|
}
|
|
for _, toolCall := range m.ToolCalls {
|
|
contentBlocks = append(contentBlocks, map[string]interface{}{
|
|
"type": "tool_use",
|
|
"id": toolCall.ID,
|
|
"name": toolCall.Name,
|
|
"input": toolCall.Parameters,
|
|
})
|
|
}
|
|
content = contentBlocks
|
|
|
|
case api.MessageRoleToolResult:
|
|
role = "user"
|
|
contentBlocks := make([]map[string]interface{}, 0)
|
|
for _, result := range m.ToolResults {
|
|
contentBlock := map[string]interface{}{
|
|
"type": "tool_result",
|
|
"tool_use_id": result.ToolCallID,
|
|
"content": result.Result,
|
|
}
|
|
contentBlocks = append(contentBlocks, contentBlock)
|
|
}
|
|
content = contentBlocks
|
|
|
|
default:
|
|
content = m.Content
|
|
}
|
|
|
|
requestMessages = append(requestMessages, ChatCompletionMessage{
|
|
Role: role,
|
|
Content: content,
|
|
})
|
|
}
|
|
|
|
request := ChatCompletionRequest{
|
|
Model: params.Model,
|
|
Messages: requestMessages,
|
|
System: systemMessage,
|
|
MaxTokens: params.MaxTokens,
|
|
Temperature: params.Temperature,
|
|
}
|
|
|
|
if len(params.Toolbox) > 0 {
|
|
request.Tools = convertTools(params.Toolbox)
|
|
}
|
|
|
|
var prefill string
|
|
if len(messages) > 0 && messages[len(messages)-1].Role == api.MessageRoleAssistant {
|
|
// Prompting on an assitant message, use its content as prefill
|
|
prefill = messages[len(messages)-1].Content
|
|
}
|
|
|
|
return prefill, request
|
|
}
|
|
|
|
func (c *AnthropicClient) sendRequest(ctx context.Context, r ChatCompletionRequest) (*http.Response, error) {
|
|
jsonData, err := json.Marshal(r)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/v1/messages", bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("x-api-key", c.APIKey)
|
|
req.Header.Set("anthropic-version", ANTHROPIC_VERSION)
|
|
req.Header.Set("content-type", "application/json")
|
|
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if resp.StatusCode != 200 {
|
|
bytes, _ := io.ReadAll(resp.Body)
|
|
return resp, fmt.Errorf("%v", string(bytes))
|
|
}
|
|
|
|
return resp, err
|
|
}
|
|
|
|
func (c *AnthropicClient) CreateChatCompletion(
|
|
ctx context.Context,
|
|
params provider.RequestParameters,
|
|
messages []api.Message,
|
|
) (*api.Message, error) {
|
|
if len(messages) == 0 {
|
|
return nil, fmt.Errorf("can't create completion from no messages")
|
|
}
|
|
|
|
_, req := createChatCompletionRequest(params, messages)
|
|
req.Stream = false
|
|
|
|
resp, err := c.sendRequest(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to send request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
var completionResp ChatCompletionResponse
|
|
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decode response: %w", err)
|
|
}
|
|
|
|
return convertResponseToMessage(completionResp)
|
|
}
|
|
|
|
func (c *AnthropicClient) CreateChatCompletionStream(
|
|
ctx context.Context,
|
|
params provider.RequestParameters,
|
|
messages []api.Message,
|
|
output chan<- provider.Chunk,
|
|
) (*api.Message, error) {
|
|
if len(messages) == 0 {
|
|
return nil, fmt.Errorf("can't create completion from no messages")
|
|
}
|
|
|
|
prefill, req := createChatCompletionRequest(params, messages)
|
|
req.Stream = true
|
|
|
|
resp, err := c.sendRequest(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to send request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
contentBlocks := make(map[int]*ContentBlock)
|
|
var finalMessage *ChatCompletionResponse
|
|
|
|
var firstChunkReceived bool
|
|
|
|
reader := bufio.NewReader(resp.Body)
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
default:
|
|
line, err := reader.ReadBytes('\n')
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
return nil, fmt.Errorf("error reading stream: %w", err)
|
|
}
|
|
|
|
line = bytes.TrimSpace(line)
|
|
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
|
|
continue
|
|
}
|
|
|
|
line = bytes.TrimPrefix(line, []byte("data: "))
|
|
|
|
var streamEvent StreamEvent
|
|
err = json.Unmarshal(line, &streamEvent)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal stream event: %w", err)
|
|
}
|
|
|
|
switch streamEvent.Type {
|
|
case "message_start":
|
|
finalMessage = &ChatCompletionResponse{}
|
|
err = json.Unmarshal(line, &struct {
|
|
Message *ChatCompletionResponse `json:"message"`
|
|
}{Message: finalMessage})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal message_start: %w", err)
|
|
}
|
|
case "content_block_start":
|
|
var contentBlockStart struct {
|
|
Index int `json:"index"`
|
|
ContentBlock ContentBlock `json:"content_block"`
|
|
}
|
|
err = json.Unmarshal(line, &contentBlockStart)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal content_block_start: %w", err)
|
|
}
|
|
|
|
contentBlocks[contentBlockStart.Index] = &contentBlockStart.ContentBlock
|
|
case "content_block_delta":
|
|
if streamEvent.Index >= len(contentBlocks) {
|
|
return nil, fmt.Errorf("received delta for non-existent content block index: %d", streamEvent.Index)
|
|
}
|
|
|
|
block := contentBlocks[streamEvent.Index]
|
|
delta, ok := streamEvent.Delta.(map[string]interface{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("unexpected delta type: %T", streamEvent.Delta)
|
|
}
|
|
|
|
deltaType, ok := delta["type"].(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("delta missing type field")
|
|
}
|
|
|
|
switch deltaType {
|
|
case "text_delta":
|
|
if text, ok := delta["text"].(string); ok {
|
|
if !firstChunkReceived {
|
|
if prefill == "" {
|
|
// if there is no prefil, ensure we trim leading whitespace
|
|
text = strings.TrimSpace(text)
|
|
}
|
|
firstChunkReceived = true
|
|
}
|
|
block.Text += text
|
|
output <- provider.Chunk{
|
|
Content: text,
|
|
// rough, anthropic performs some chunking
|
|
TokenCount: uint(len(strings.Split(text, " "))),
|
|
}
|
|
}
|
|
case "input_json_delta":
|
|
if block.Type != "tool_use" {
|
|
return nil, fmt.Errorf("received input_json_delta for non-tool_use block")
|
|
}
|
|
if partialJSON, ok := delta["partial_json"].(string); ok {
|
|
block.partialJsonAccumulator += partialJSON
|
|
}
|
|
}
|
|
|
|
case "content_block_stop":
|
|
if streamEvent.Index >= len(contentBlocks) {
|
|
return nil, fmt.Errorf("received stop for non-existent content block index: %d", streamEvent.Index)
|
|
}
|
|
|
|
block := contentBlocks[streamEvent.Index]
|
|
if block.Type == "tool_use" && block.partialJsonAccumulator != "" {
|
|
var inputData map[string]interface{}
|
|
err := json.Unmarshal([]byte(block.partialJsonAccumulator), &inputData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal accumulated JSON for tool use: %w", err)
|
|
}
|
|
block.Input = inputData
|
|
}
|
|
case "message_delta":
|
|
if finalMessage == nil {
|
|
return nil, fmt.Errorf("received message_delta before message_start")
|
|
}
|
|
delta, ok := streamEvent.Delta.(map[string]interface{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("unexpected delta type in message_delta: %T", streamEvent.Delta)
|
|
}
|
|
if stopReason, ok := delta["stop_reason"].(string); ok {
|
|
finalMessage.StopReason = stopReason
|
|
}
|
|
|
|
case "message_stop":
|
|
// End of the stream
|
|
goto END_STREAM
|
|
|
|
case "error":
|
|
return nil, fmt.Errorf("received error event: %v", streamEvent.Message)
|
|
|
|
default:
|
|
// Ignore unknown event types
|
|
}
|
|
}
|
|
}
|
|
|
|
END_STREAM:
|
|
if finalMessage == nil {
|
|
return nil, fmt.Errorf("no final message received")
|
|
}
|
|
|
|
finalMessage.Content = make([]ContentBlock, len(contentBlocks))
|
|
for _, v := range contentBlocks {
|
|
finalMessage.Content = append(finalMessage.Content, *v)
|
|
}
|
|
|
|
return convertResponseToMessage(*finalMessage)
|
|
}
|
|
|
|
func convertResponseToMessage(resp ChatCompletionResponse) (*api.Message, error) {
|
|
content := strings.Builder{}
|
|
var toolCalls []api.ToolCall
|
|
|
|
for _, block := range resp.Content {
|
|
switch block.Type {
|
|
case "text":
|
|
content.WriteString(block.Text)
|
|
case "tool_use":
|
|
parameters, ok := block.Input.(map[string]interface{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("unexpected type for tool call parameters: %T", block.Input)
|
|
}
|
|
toolCalls = append(toolCalls, api.ToolCall{
|
|
ID: block.ID,
|
|
Name: block.Name,
|
|
Parameters: parameters,
|
|
})
|
|
}
|
|
}
|
|
|
|
if len(toolCalls) > 0 {
|
|
return api.NewMessageWithToolCalls(content.String(), toolCalls), nil
|
|
}
|
|
|
|
return api.NewMessageWithAssistant(content.String()), nil
|
|
}
|