lmcli/pkg/api/provider/anthropic/tools.go

233 lines
6.4 KiB
Go
Raw Normal View History

package anthropic
import (
	"bytes"
	"fmt"
	"sort"
	"strings"
	"text/template"

	"git.mlow.ca/mlow/lmcli/pkg/api"
)
2024-04-29 00:16:41 -06:00
// Prompt fragments and markers used for XML-style tool calling.
const (
	// FUNCTION_STOP_SEQUENCE is the closing tag that terminates a
	// <function_calls> block.
	// NOTE(review): presumably sent to the API as a stop sequence — confirm
	// against the caller.
	FUNCTION_STOP_SEQUENCE = "</function_calls>"

	// TOOL_PREAMBLE opens the tools system prompt: it shows the expected
	// <function_calls> syntax and introduces the tool list that follows.
	TOOL_PREAMBLE = `You have access to the following tools when replying.
You may call them like this:
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>
Here are the tools available:`

	// TOOL_PREAMBLE_FOOTER closes the tools system prompt with guidance on
	// when tool use is (and is not) appropriate.
	TOOL_PREAMBLE_FOOTER = `Recognize the utility of these tools in a broad range of different applications, and the power they give you to solve a wide range of different problems. However, ensure that the tools are used judiciously and only when clearly relevant to the user's request. Specifically:
1. Only use a tool if the user has explicitly requested or provided information that warrants its use. Do not make assumptions about files or data existing without the user mentioning them.
2. If there is ambiguity about whether using a tool is appropriate, ask a clarifying question to the user before proceeding. Confirm your understanding of their request and intent.
3. Prioritize providing direct responses and explanations based on your own knowledge and understanding. Use tools to supplement and enhance your responses when clearly applicable, but not as a default action.`
)
// XMLTools is the <tools> document listing every available tool. Its
// XMLString method renders it for inclusion in the tools system prompt.
type XMLTools struct {
// XMLName carries the "tools" element name in its struct tag.
XMLName struct{} `xml:"tools"`
// ToolDescriptions holds one <tool_description> entry per tool.
ToolDescriptions []XMLToolDescription `xml:"tool_description"`
}
// XMLToolDescription describes a single tool: its name, a free-text
// description, and the parameters it accepts.
type XMLToolDescription struct {
// ToolName is the name of the tool (copied from api.ToolSpec.Name).
ToolName string `xml:"tool_name"`
// Description is the tool's description text.
Description string `xml:"description"`
// Parameters lists the tool's parameters, nested as
// <parameters><parameter>...</parameter></parameters>.
Parameters []XMLToolParameter `xml:"parameters>parameter"`
}
// XMLToolParameter describes one parameter of a tool.
type XMLToolParameter struct {
// Name is the parameter name.
Name string `xml:"name"`
// Type is the parameter's type, as given by the tool spec.
Type string `xml:"type"`
// Description is the parameter's description text.
Description string `xml:"description"`
}
// XMLFunctionCalls is a <function_calls> block: a list of tool invocations.
// It is converted to and from []api.ToolCall by the helpers in this file.
type XMLFunctionCalls struct {
// XMLName carries the "function_calls" element name in its struct tag.
XMLName struct{} `xml:"function_calls"`
// Invoke holds one entry per <invoke> element.
Invoke []XMLFunctionInvoke `xml:"invoke"`
}
// XMLFunctionInvoke is a single <invoke> element: the tool to call and the
// parameters to call it with.
type XMLFunctionInvoke struct {
// ToolName is the name of the tool being invoked.
ToolName string `xml:"tool_name"`
// Parameters wraps the raw contents of the <parameters> element.
Parameters XMLFunctionInvokeParameters `xml:"parameters"`
}
// XMLFunctionInvokeParameters keeps the contents of a <parameters> element
// as unparsed text, because parameter tag names vary per tool.
type XMLFunctionInvokeParameters struct {
// String is the raw inner XML of <parameters>; parseFunctionParametersXML
// turns it into a name-to-value map.
String string `xml:",innerxml"`
}
// XMLFunctionResults is a <function_results> block reporting the outcome of
// tool calls. Its XMLString method renders it as text.
type XMLFunctionResults struct {
// XMLName carries the "function_results" element name in its struct tag.
XMLName struct{} `xml:"function_results"`
// Result holds one entry per <result> element.
Result []XMLFunctionResult `xml:"result"`
}
// XMLFunctionResult is a single <result> element for one tool call.
type XMLFunctionResult struct {
// ToolName is the name of the tool that produced the result.
ToolName string `xml:"tool_name"`
// Stdout is the tool's result text (filled from api.ToolResult.Result).
Stdout string `xml:"stdout"`
}
// parseFunctionParametersXML accepts the raw XML held in
// XMLFunctionInvokeParameters.String and returns a map of parameter name to
// parameter value.
//
// The input is expected to be a sequence of <name>value</name> elements.
// Whitespace between elements (newlines, indentation) is ignored, and a value
// may span multiple lines; it is returned exactly as written between its
// opening tag and the matching closing tag, with no trimming or entity
// unescaping. Every value is stored as a string.
//
// Malformed input never panics: stray closing tags, empty tag names and
// elements without a matching closing tag are skipped. If a name occurs more
// than once, the last occurrence wins. The returned map is never nil.
func parseFunctionParametersXML(params string) map[string]interface{} {
	ret := make(map[string]interface{})

	rest := params
	for {
		open := strings.Index(rest, "<")
		if open == -1 {
			break
		}
		// Always move past this "<" so that every iteration makes progress.
		rest = rest[open+1:]

		end := strings.Index(rest, ">")
		if end == -1 {
			break
		}
		name := rest[:end]

		// Not an opening parameter tag: an empty name, a closing tag, or a
		// literal "<" that merely happened to precede some later ">".
		if name == "" || strings.HasPrefix(name, "/") || strings.ContainsAny(name, "< \t\r\n") {
			continue
		}

		body := rest[end+1:]
		closeTag := "</" + name + ">"
		stop := strings.Index(body, closeTag)
		if stop == -1 {
			// Unterminated element: drop it, but keep scanning its would-be
			// contents for well-formed elements.
			rest = body
			continue
		}

		ret[name] = body[:stop]
		rest = body[stop+len(closeTag):]
	}

	return ret
}
// convertToolsToXMLTools maps each api.ToolSpec (name, description and
// parameters) onto the XML description structures and wraps the result in an
// XMLTools value. Order is preserved for both tools and parameters.
func convertToolsToXMLTools(tools []api.ToolSpec) XMLTools {
	descriptions := make([]XMLToolDescription, len(tools))
	for idx := range tools {
		spec := tools[idx]

		parameters := make([]XMLToolParameter, len(spec.Parameters))
		for p, src := range spec.Parameters {
			parameters[p] = XMLToolParameter{
				Name:        src.Name,
				Type:        src.Type,
				Description: src.Description,
			}
		}

		descriptions[idx] = XMLToolDescription{
			ToolName:    spec.Name,
			Description: spec.Description,
			Parameters:  parameters,
		}
	}
	return XMLTools{ToolDescriptions: descriptions}
}
// convertXMLFunctionCallsToToolCalls turns every <invoke> entry of
// functionCalls into an api.ToolCall, copying the tool name and parsing the
// raw parameter XML into a name-to-value map. Only Name and Parameters are
// set on each returned call.
func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []api.ToolCall {
	invocations := functionCalls.Invoke
	calls := make([]api.ToolCall, len(invocations))
	for idx := range invocations {
		src, dst := invocations[idx], &calls[idx]
		dst.Name = src.ToolName
		dst.Parameters = parseFunctionParametersXML(src.Parameters.String)
	}
	return calls
}
// convertToolCallsToXMLFunctionCalls turns each api.ToolCall into an <invoke>
// entry whose parameter XML is a sequence of "<name>value</name>\n" lines.
//
// Parameters are emitted in sorted key order. Go randomizes map iteration
// order, so without sorting the same tool call would serialize differently
// from run to run. Values are formatted with %v and written as-is (no XML
// escaping), mirroring parseFunctionParametersXML, which does no unescaping.
func convertToolCallsToXMLFunctionCalls(toolCalls []api.ToolCall) XMLFunctionCalls {
	converted := make([]XMLFunctionInvoke, len(toolCalls))
	for i, toolCall := range toolCalls {
		keys := make([]string, 0, len(toolCall.Parameters))
		for key := range toolCall.Parameters {
			keys = append(keys, key)
		}
		sort.Strings(keys)

		var paramXML strings.Builder
		for _, key := range keys {
			fmt.Fprintf(&paramXML, "<%s>%v</%s>\n", key, toolCall.Parameters[key], key)
		}

		converted[i] = XMLFunctionInvoke{
			ToolName:   toolCall.Name,
			Parameters: XMLFunctionInvokeParameters{String: paramXML.String()},
		}
	}
	return XMLFunctionCalls{
		Invoke: converted,
	}
}
// convertToolResultsToXMLFunctionResult wraps tool results in an
// XMLFunctionResults value, producing one <result> entry per api.ToolResult
// (tool name plus result text), in the same order as the input.
func convertToolResultsToXMLFunctionResult(toolResults []api.ToolResult) XMLFunctionResults {
	results := make([]XMLFunctionResult, 0, len(toolResults))
	for _, tr := range toolResults {
		results = append(results, XMLFunctionResult{
			ToolName: tr.ToolName,
			Stdout:   tr.Result,
		})
	}
	return XMLFunctionResults{Result: results}
}
// buildToolsSystemPrompt assembles the tools system prompt: TOOL_PREAMBLE,
// the XML rendering of the given tool specs, and TOOL_PREAMBLE_FOOTER, each
// separated by a blank line.
//
// It panics if the tools cannot be rendered. The panic message includes the
// underlying error so the cause is not lost.
func buildToolsSystemPrompt(tools []api.ToolSpec) string {
	xmlTools := convertToolsToXMLTools(tools)
	xmlToolsString, err := xmlTools.XMLString()
	if err != nil {
		panic(fmt.Sprintf("Could not serialize []api.Tool to XMLTools: %v", err))
	}
	return TOOL_PREAMBLE + "\n\n" + xmlToolsString + "\n\n" + TOOL_PREAMBLE_FOOTER
}
// XMLString renders the tool descriptions as a <tools> document using a
// text template. Field values are inserted verbatim (text/template performs
// no escaping). It returns an empty string and the error if the template
// fails to parse or execute.
func (x XMLTools) XMLString() (string, error) {
	const layout = `<tools>
{{range .ToolDescriptions}}<tool_description>
<tool_name>{{.ToolName}}</tool_name>
<description>
{{.Description}}
</description>
<parameters>
{{range .Parameters}}<parameter>
<name>{{.Name}}</name>
<type>{{.Type}}</type>
<description>{{.Description}}</description>
</parameter>
{{end}}</parameters>
</tool_description>
{{end}}</tools>`

	parsed, parseErr := template.New("tools").Parse(layout)
	if parseErr != nil {
		return "", parseErr
	}

	out := new(bytes.Buffer)
	if execErr := parsed.Execute(out, x); execErr != nil {
		return "", execErr
	}
	return out.String(), nil
}
// XMLString renders the results as a <function_results> document with one
// <result> element (tool name and stdout) per entry. Values are inserted
// verbatim. It returns an empty string and the error if the template fails to
// parse or execute.
func (x XMLFunctionResults) XMLString() (string, error) {
	const layout = `<function_results>
{{range .Result}}<result>
<tool_name>{{.ToolName}}</tool_name>
<stdout>{{.Stdout}}</stdout>
</result>
{{end}}</function_results>`

	parsed, parseErr := template.New("function_results").Parse(layout)
	if parseErr != nil {
		return "", parseErr
	}

	out := new(bytes.Buffer)
	if execErr := parsed.Execute(out, x); execErr != nil {
		return "", execErr
	}
	return out.String(), nil
}
// XMLString renders the calls as a <function_calls> document with one
// <invoke> element per entry; each invoke's raw parameter XML is placed
// verbatim inside <parameters>. It returns an empty string and the error if
// the template fails to parse or execute.
func (x XMLFunctionCalls) XMLString() (string, error) {
	const layout = `<function_calls>
{{range .Invoke}}<invoke>
<tool_name>{{.ToolName}}</tool_name>
<parameters>{{.Parameters.String}}</parameters>
</invoke>
{{end}}</function_calls>`

	parsed, parseErr := template.New("function_calls").Parse(layout)
	if parseErr != nil {
		return "", parseErr
	}

	out := new(bytes.Buffer)
	if execErr := parsed.Execute(out, x); execErr != nil {
		return "", execErr
	}
	return out.String(), nil
}