Compare commits
6 Commits
3f765234de
...
e9ce46a250
Author | SHA1 | Date | |
---|---|---|---|
e9ce46a250 | |||
79ff77a73f | |||
407f4a99bf | |||
79ad681f4c | |||
78c80431b4 | |||
32eab7aa35 |
@ -11,6 +11,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -20,7 +21,7 @@ type AnthropicClient struct {
|
|||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
OriginalContent string `json:"content"`
|
Content string `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Request struct {
|
type Request struct {
|
||||||
@ -41,10 +42,10 @@ type OriginalContent struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Response struct {
|
type Response struct {
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
OriginalContent []OriginalContent `json:"content"`
|
Content []OriginalContent `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
|
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
|
||||||
@ -65,7 +66,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
|
|||||||
}
|
}
|
||||||
|
|
||||||
startIdx := 0
|
startIdx := 0
|
||||||
if messages[0].Role == model.MessageRoleSystem {
|
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
|
||||||
requestBody.System = messages[0].Content
|
requestBody.System = messages[0].Content
|
||||||
requestBody.Messages = requestBody.Messages[:len(messages)-1]
|
requestBody.Messages = requestBody.Messages[:len(messages)-1]
|
||||||
startIdx = 1
|
startIdx = 1
|
||||||
@ -85,8 +86,15 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
|
|||||||
switch msg.Role {
|
switch msg.Role {
|
||||||
case model.MessageRoleToolCall:
|
case model.MessageRoleToolCall:
|
||||||
message.Role = "assistant"
|
message.Role = "assistant"
|
||||||
message.OriginalContent = msg.Content
|
if msg.Content != "" {
|
||||||
//message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls)
|
message.Content = msg.Content
|
||||||
|
}
|
||||||
|
xmlFuncCalls := convertToolCallsToXMLFunctionCalls(msg.ToolCalls)
|
||||||
|
xmlString, err := xmlFuncCalls.XMLString()
|
||||||
|
if err != nil {
|
||||||
|
panic("Could not serialize []ToolCall to XMLFunctionCall")
|
||||||
|
}
|
||||||
|
message.Content += xmlString
|
||||||
case model.MessageRoleToolResult:
|
case model.MessageRoleToolResult:
|
||||||
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
|
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
|
||||||
xmlString, err := xmlFuncResults.XMLString()
|
xmlString, err := xmlFuncResults.XMLString()
|
||||||
@ -94,10 +102,10 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
|
|||||||
panic("Could not serialize []ToolResult to XMLFunctionResults")
|
panic("Could not serialize []ToolResult to XMLFunctionResults")
|
||||||
}
|
}
|
||||||
message.Role = "user"
|
message.Role = "user"
|
||||||
message.OriginalContent = xmlString
|
message.Content = xmlString
|
||||||
default:
|
default:
|
||||||
message.Role = string(msg.Role)
|
message.Role = string(msg.Role)
|
||||||
message.OriginalContent = msg.Content
|
message.Content = msg.Content
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return requestBody
|
return requestBody
|
||||||
@ -150,7 +158,7 @@ func (c *AnthropicClient) CreateChatCompletion(
|
|||||||
}
|
}
|
||||||
|
|
||||||
sb := strings.Builder{}
|
sb := strings.Builder{}
|
||||||
for _, content := range response.OriginalContent {
|
for _, content := range response.Content {
|
||||||
var reply model.Message
|
var reply model.Message
|
||||||
switch content.Type {
|
switch content.Type {
|
||||||
case "text":
|
case "text":
|
||||||
@ -278,8 +286,9 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
|||||||
|
|
||||||
// Execute function calls
|
// Execute function calls
|
||||||
toolCall := model.Message{
|
toolCall := model.Message{
|
||||||
Role: model.MessageRoleToolCall,
|
Role: model.MessageRoleToolCall,
|
||||||
Content: content,
|
// xml stripped from content
|
||||||
|
Content: content[:start],
|
||||||
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
|
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ package anthropic
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
@ -114,6 +115,25 @@ func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []model.
|
|||||||
return toolCalls
|
return toolCalls
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func convertToolCallsToXMLFunctionCalls(toolCalls []model.ToolCall) XMLFunctionCalls {
|
||||||
|
converted := make([]XMLFunctionInvoke, len(toolCalls))
|
||||||
|
for i, toolCall := range toolCalls {
|
||||||
|
var params XMLFunctionInvokeParameters
|
||||||
|
var paramXML string
|
||||||
|
for key, value := range toolCall.Parameters {
|
||||||
|
paramXML += fmt.Sprintf("<%s>%v</%s>\n", key, value, key)
|
||||||
|
}
|
||||||
|
params.String = paramXML
|
||||||
|
converted[i] = XMLFunctionInvoke{
|
||||||
|
ToolName: toolCall.Name,
|
||||||
|
Parameters: params,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return XMLFunctionCalls{
|
||||||
|
Invoke: converted,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func convertToolResultsToXMLFunctionResult(toolResults []model.ToolResult) XMLFunctionResults {
|
func convertToolResultsToXMLFunctionResult(toolResults []model.ToolResult) XMLFunctionResults {
|
||||||
converted := make([]XMLFunctionResult, len(toolResults))
|
converted := make([]XMLFunctionResult, len(toolResults))
|
||||||
for i, result := range toolResults {
|
for i, result := range toolResults {
|
||||||
@ -180,3 +200,22 @@ func (x XMLFunctionResults) XMLString() (string, error) {
|
|||||||
|
|
||||||
return buf.String(), nil
|
return buf.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (x XMLFunctionCalls) XMLString() (string, error) {
|
||||||
|
tmpl, err := template.New("function_calls").Parse(`<function_calls>
|
||||||
|
{{range .Invoke}}<invoke>
|
||||||
|
<tool_name>{{.ToolName}}</tool_name>
|
||||||
|
<parameters>{{.Parameters.String}}</parameters>
|
||||||
|
</invoke>
|
||||||
|
{{end}}</function_calls>`)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := tmpl.Execute(&buf, x); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user