Compare commits
5 Commits
e9ce46a250
...
3f765234de
Author | SHA1 | Date | |
---|---|---|---|
3f765234de | |||
21411c2732 | |||
99794addee | |||
a47c1a76b4 | |||
96fdae982e |
@ -11,7 +11,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||
)
|
||||
|
||||
@ -21,7 +20,7 @@ type AnthropicClient struct {
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
OriginalContent string `json:"content"`
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
@ -42,10 +41,10 @@ type OriginalContent struct {
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Id string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role"`
|
||||
Content []OriginalContent `json:"content"`
|
||||
Id string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role"`
|
||||
OriginalContent []OriginalContent `json:"content"`
|
||||
}
|
||||
|
||||
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
|
||||
@ -66,7 +65,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
|
||||
}
|
||||
|
||||
startIdx := 0
|
||||
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
|
||||
if messages[0].Role == model.MessageRoleSystem {
|
||||
requestBody.System = messages[0].Content
|
||||
requestBody.Messages = requestBody.Messages[:len(messages)-1]
|
||||
startIdx = 1
|
||||
@ -86,15 +85,8 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
|
||||
switch msg.Role {
|
||||
case model.MessageRoleToolCall:
|
||||
message.Role = "assistant"
|
||||
if msg.Content != "" {
|
||||
message.Content = msg.Content
|
||||
}
|
||||
xmlFuncCalls := convertToolCallsToXMLFunctionCalls(msg.ToolCalls)
|
||||
xmlString, err := xmlFuncCalls.XMLString()
|
||||
if err != nil {
|
||||
panic("Could not serialize []ToolCall to XMLFunctionCall")
|
||||
}
|
||||
message.Content += xmlString
|
||||
message.OriginalContent = msg.Content
|
||||
//message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls)
|
||||
case model.MessageRoleToolResult:
|
||||
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
|
||||
xmlString, err := xmlFuncResults.XMLString()
|
||||
@ -102,10 +94,10 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
|
||||
panic("Could not serialize []ToolResult to XMLFunctionResults")
|
||||
}
|
||||
message.Role = "user"
|
||||
message.Content = xmlString
|
||||
message.OriginalContent = xmlString
|
||||
default:
|
||||
message.Role = string(msg.Role)
|
||||
message.Content = msg.Content
|
||||
message.OriginalContent = msg.Content
|
||||
}
|
||||
}
|
||||
return requestBody
|
||||
@ -158,7 +150,7 @@ func (c *AnthropicClient) CreateChatCompletion(
|
||||
}
|
||||
|
||||
sb := strings.Builder{}
|
||||
for _, content := range response.Content {
|
||||
for _, content := range response.OriginalContent {
|
||||
var reply model.Message
|
||||
switch content.Type {
|
||||
case "text":
|
||||
@ -286,9 +278,8 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
||||
|
||||
// Execute function calls
|
||||
toolCall := model.Message{
|
||||
Role: model.MessageRoleToolCall,
|
||||
// xml stripped from content
|
||||
Content: content[:start],
|
||||
Role: model.MessageRoleToolCall,
|
||||
Content: content,
|
||||
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
|
||||
}
|
||||
|
||||
|
@ -2,7 +2,6 @@ package anthropic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
@ -115,25 +114,6 @@ func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []model.
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
func convertToolCallsToXMLFunctionCalls(toolCalls []model.ToolCall) XMLFunctionCalls {
|
||||
converted := make([]XMLFunctionInvoke, len(toolCalls))
|
||||
for i, toolCall := range toolCalls {
|
||||
var params XMLFunctionInvokeParameters
|
||||
var paramXML string
|
||||
for key, value := range toolCall.Parameters {
|
||||
paramXML += fmt.Sprintf("<%s>%v</%s>\n", key, value, key)
|
||||
}
|
||||
params.String = paramXML
|
||||
converted[i] = XMLFunctionInvoke{
|
||||
ToolName: toolCall.Name,
|
||||
Parameters: params,
|
||||
}
|
||||
}
|
||||
return XMLFunctionCalls{
|
||||
Invoke: converted,
|
||||
}
|
||||
}
|
||||
|
||||
func convertToolResultsToXMLFunctionResult(toolResults []model.ToolResult) XMLFunctionResults {
|
||||
converted := make([]XMLFunctionResult, len(toolResults))
|
||||
for i, result := range toolResults {
|
||||
@ -200,22 +180,3 @@ func (x XMLFunctionResults) XMLString() (string, error) {
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func (x XMLFunctionCalls) XMLString() (string, error) {
|
||||
tmpl, err := template.New("function_calls").Parse(`<function_calls>
|
||||
{{range .Invoke}}<invoke>
|
||||
<tool_name>{{.ToolName}}</tool_name>
|
||||
<parameters>{{.Parameters.String}}</parameters>
|
||||
</invoke>
|
||||
{{end}}</function_calls>`)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := tmpl.Execute(&buf, x); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user