diff --git a/pkg/lmcli/provider/anthropic/anthropic.go b/pkg/lmcli/provider/anthropic/anthropic.go index d889428..7ed1645 100644 --- a/pkg/lmcli/provider/anthropic/anthropic.go +++ b/pkg/lmcli/provider/anthropic/anthropic.go @@ -11,6 +11,7 @@ import ( "strings" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" + "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider" "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" ) @@ -20,7 +21,7 @@ type AnthropicClient struct { type Message struct { Role string `json:"role"` - OriginalContent string `json:"content"` + Content string `json:"content"` } type Request struct { @@ -41,10 +42,10 @@ type OriginalContent struct { } type Response struct { - Id string `json:"id"` - Type string `json:"type"` - Role string `json:"role"` - OriginalContent []OriginalContent `json:"content"` + Id string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []OriginalContent `json:"content"` } const FUNCTION_STOP_SEQUENCE = "" @@ -65,7 +66,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ } startIdx := 0 - if messages[0].Role == model.MessageRoleSystem { + if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem { requestBody.System = messages[0].Content requestBody.Messages = requestBody.Messages[:len(messages)-1] startIdx = 1 @@ -85,8 +86,15 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ switch msg.Role { case model.MessageRoleToolCall: message.Role = "assistant" - message.OriginalContent = msg.Content - //message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls) + if msg.Content != "" { + message.Content = msg.Content + } + xmlFuncCalls := convertToolCallsToXMLFunctionCalls(msg.ToolCalls) + xmlString, err := xmlFuncCalls.XMLString() + if err != nil { + panic("Could not serialize []ToolCall to XMLFunctionCall") + } + message.Content += xmlString case model.MessageRoleToolResult: xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults) xmlString, err := xmlFuncResults.XMLString() @@ -94,10 +102,10 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ panic("Could not serialize []ToolResult to XMLFunctionResults") } message.Role = "user" - message.OriginalContent = xmlString + message.Content = xmlString default: message.Role = string(msg.Role) - message.OriginalContent = msg.Content + message.Content = msg.Content } } return requestBody @@ -150,7 +158,7 @@ func (c *AnthropicClient) CreateChatCompletion( } sb := strings.Builder{} - for _, content := range response.OriginalContent { + for _, content := range response.Content { var reply model.Message switch content.Type { case "text": @@ -278,8 +286,9 @@ func (c *AnthropicClient) CreateChatCompletionStream( // Execute function calls toolCall := model.Message{ - Role: model.MessageRoleToolCall, - Content: content, + Role: model.MessageRoleToolCall, + // xml stripped from content + Content: content[:start], ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls), } diff --git a/pkg/lmcli/provider/anthropic/tools.go b/pkg/lmcli/provider/anthropic/tools.go index de89d4a..7c52b1e 100644 --- a/pkg/lmcli/provider/anthropic/tools.go +++ b/pkg/lmcli/provider/anthropic/tools.go @@ -2,6 +2,7 @@ package anthropic import ( "bytes" + "fmt" "strings" "text/template" @@ -114,6 +115,25 @@ func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []model. return toolCalls } +func convertToolCallsToXMLFunctionCalls(toolCalls []model.ToolCall) XMLFunctionCalls { + converted := make([]XMLFunctionInvoke, len(toolCalls)) + for i, toolCall := range toolCalls { + var params XMLFunctionInvokeParameters + var paramXML string + for key, value := range toolCall.Parameters { + paramXML += fmt.Sprintf("<%s>%v\n", key, value, key) + } + params.String = paramXML + converted[i] = XMLFunctionInvoke{ + ToolName: toolCall.Name, + Parameters: params, + } + } + return XMLFunctionCalls{ + Invoke: converted, + } +} + func convertToolResultsToXMLFunctionResult(toolResults []model.ToolResult) XMLFunctionResults { converted := make([]XMLFunctionResult, len(toolResults)) for i, result := range toolResults { @@ -180,3 +200,22 @@ func (x XMLFunctionResults) XMLString() (string, error) { return buf.String(), nil } + +func (x XMLFunctionCalls) XMLString() (string, error) { + tmpl, err := template.New("function_calls").Parse(` +{{range .Invoke}} +{{.ToolName}} +{{.Parameters.String}} + +{{end}}`) + if err != nil { + return "", err + } + + var buf bytes.Buffer + if err := tmpl.Execute(&buf, x); err != nil { + return "", err + } + + return buf.String(), nil +}